How to calculate moving median in DataFrame?

Is there a way to calculate a moving median for an attribute in a Spark DataFrame?

I was hoping it would be possible to calculate a moving median using a window function (by defining a window with rowsBetween(0, 10)), but there is no built-in function for it (the way there is for average or mean).

There are 3 best solutions below

In Spark 2.1+, to find the median we can use the functions percentile and percentile_approx. We can use both in aggregations and with window functions, and, as you originally wanted, you can use rowsBetween() too.

Examples using PySpark:

from pyspark.sql import SparkSession, functions as F, Window as W
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 10),
     (1, 20),
     (1, 30),
     (1, 40),
     (1, 50),
     (2, 50)],
    ['c1', 'c2']
)
(
    df
    .withColumn(
        'moving_median_1',
        F.expr('percentile(c2, 0.5)').over(W.partitionBy('c1').orderBy('c2')))
    .withColumn(
        'moving_median_2',
        F.expr('percentile(c2, 0.5) over(partition by c1 order by c2)'))
    .withColumn(
        'moving_median_3_rows_1',
        F.expr('percentile(c2, 0.5)').over(W.partitionBy('c1').orderBy('c2').rowsBetween(-2, 0)))
    .withColumn(
        'moving_median_3_rows_2',
        F.expr('percentile(c2, 0.5) over(partition by c1 order by c2 rows between 2 preceding and current row)'))
).show()
#+---+---+---------------+---------------+----------------------+----------------------+
#| c1| c2|moving_median_1|moving_median_2|moving_median_3_rows_1|moving_median_3_rows_2|
#+---+---+---------------+---------------+----------------------+----------------------+
#|  1| 10|           10.0|           10.0|                  10.0|                  10.0|
#|  1| 20|           15.0|           15.0|                  15.0|                  15.0|
#|  1| 30|           20.0|           20.0|                  20.0|                  20.0|
#|  1| 40|           25.0|           25.0|                  30.0|                  30.0|
#|  1| 50|           30.0|           30.0|                  40.0|                  40.0|
#|  2| 50|           50.0|           50.0|                  50.0|                  50.0|
#+---+---+---------------+---------------+----------------------+----------------------+
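
percentile_approx should work the same way; here is a quick sketch (reusing df, F and W from above, and the frame from your question, rowsBetween(0, 10), i.e. the current row plus the 10 following rows):

df.withColumn(
    'moving_median_approx',
    F.expr('percentile_approx(c2, 0.5)').over(
        W.partitionBy('c1').orderBy('c2').rowsBetween(0, 10))
).show()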

I think you've got a few options here.

ntile window function

I think ntile(2) (over a window of rows) would give you two "segments" that in turn you could use to calculate the median over the window.

Quoting the scaladoc:

ntile(n: Int) Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window partition. For example, if n is 4, the first quarter of the rows will get value 1, the second quarter will get 2, the third quarter will get 3, and the last quarter will get 4.

This is equivalent to the NTILE function in SQL.

If one group has more rows than the other (an odd total), ntile puts the extra row in the first group, so the median is the largest value in that bigger group.

If the two groups have the same number of rows (an even total), take the maximum of the first group and the minimum of the second group and average them to get the median.

I found it quite nicely described in Calculating median using the NTILE function.
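
To make the idea concrete, here is a rough PySpark sketch (reusing the toy c1/c2 data from the first answer). Note that it computes the median per partition rather than over a sliding frame, since ntile does not take a custom frame:

from pyspark.sql import SparkSession, functions as F, Window as W

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 10), (1, 20), (1, 30), (1, 40), (1, 50), (2, 50)], ['c1', 'c2'])

# Split each partition into two ordered halves.
tiled = df.withColumn('tile', F.ntile(2).over(W.partitionBy('c1').orderBy('c2')))

# If the halves are equal in size, average the max of the first half and the
# min of the second half; otherwise the first half holds the extra row and its
# max is the median.
medians = (
    tiled
    .groupBy('c1')
    .agg(
        F.max(F.when(F.col('tile') == 1, F.col('c2'))).alias('max_tile1'),
        F.min(F.when(F.col('tile') == 2, F.col('c2'))).alias('min_tile2'),
        F.count(F.when(F.col('tile') == 1, True)).alias('n_tile1'),
        F.count(F.when(F.col('tile') == 2, True)).alias('n_tile2'))
    .withColumn(
        'median',
        F.when(F.col('n_tile1') == F.col('n_tile2'),
               (F.col('max_tile1') + F.col('min_tile2')) / 2)
         .otherwise(F.col('max_tile1')))
)
medians.show()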

percent_rank window function

I think percent_rank might also be an option to calculate the median over a window of rows.

Quoting the scaladoc:

percent_rank() Window function: returns the relative rank (i.e. percentile) of rows within a window partition.

This is computed by:

(rank of row in its partition - 1) / (number of rows in the partition - 1)

This is equivalent to the PERCENT_RANK function in SQL.
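
As a rough PySpark sketch of one way to use it (again on the toy c1/c2 data): keep the row whose percent_rank is closest to 0.5 in each partition. Bear in mind that for an even row count this picks one of the two middle values rather than their average:

from pyspark.sql import SparkSession, functions as F, Window as W

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 10), (1, 20), (1, 30), (1, 40), (1, 50), (2, 50)], ['c1', 'c2'])

# Relative rank of every row within its partition, as a fraction in [0, 1].
ranked = df.withColumn(
    'pr', F.percent_rank().over(W.partitionBy('c1').orderBy('c2')))

# Keep the row whose percent_rank is closest to 0.5.
closest_to_middle = W.partitionBy('c1').orderBy(F.abs(F.col('pr') - 0.5))
medians = (
    ranked
    .withColumn('rn', F.row_number().over(closest_to_middle))
    .where(F.col('rn') == 1)
    .select('c1', F.col('c2').alias('median'))
)
medians.show()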

User-Defined Aggregate Function (UDAF)

You could write a user-defined aggregate function (UDAF) to calculate median over a window.

A UDAF extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction which is (quoting the scaladoc):

The base class for implementing user-defined aggregate functions (UDAF).

Luckily, there is a sample implementation of a custom UDAF in the UserDefinedUntypedAggregation example.

Here is the class I wrote by extending UserDefinedAggregateFunction to get a moving median.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyMedian extends UserDefinedAggregateFunction {
  // Single input column of type Double.
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

  // Intermediate buffer: all values seen so far in the window frame.
  def bufferSchema: StructType = StructType(
    StructField("window_list", ArrayType(DoubleType, false)) :: Nil
  )

  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Double]
  }

  // Append the incoming value to the buffered list.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Seq[Double]](0) :+ input.getAs[Double](0)
  }

  // Concatenate the partially aggregated buffers.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Seq[Double]](0) ++ buffer2.getAs[Seq[Double]](0)
  }

  // Sort the buffered values and return the middle one
  // (or the average of the two middle values for an even count).
  def evaluate(buffer: Row): Any = {
    val sortedWindow = buffer.getAs[Seq[Double]](0).sorted
    val windowSize = sortedWindow.size
    if (windowSize % 2 == 0) {
      val index = windowSize / 2
      (sortedWindow(index) + sortedWindow(index - 1)) / 2
    } else {
      sortedWindow(windowSize / 2)
    }
  }
}

Using the above UDAF:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col

// Create an instance of the MyMedian UDAF.
val mm = new MyMedian

val movingMedianDS = dataSet.withColumn(
  "MovingMedian",
  mm(col("value")).over(Window.partitionBy("GroupId").rowsBetween(-10, 10))
)