Compute rolling percentiles in PySpark


I have a dataframe with dates, an ID (let's say of a city), and two temperature columns (in my real dataframe I have a dozen columns to compute).

I want to "rank" those temperatures within a given window, with the ranking scaled from 0 (the lowest temperature of the window) to 100 (the highest temperature of the same window). The window must be symmetrical (meaning I take into account as many days before as days after). My test dataframe looks like this:

+-----------+-------+-----------------+-----------------+
|DATE_TICKET|ID_SITE|MAX_TEMPERATURE_C|MIN_TEMPERATURE_C|
+-----------+-------+-----------------+-----------------+
| 2017-03-24|    001|               22|               10|
| 2017-03-25|    001|               25|               15|
| 2017-03-26|    001|               31|               19|
| 2017-03-27|    001|               29|               18|
| 2017-03-28|    001|               30|               16|
| 2017-03-29|    001|               25|               17|
| 2017-03-30|    001|               24|               16|
| 2017-03-24|    002|               18|               12|
| 2017-03-25|    002|               17|               11|
| 2017-03-27|    002|               15|                7|
| 2017-03-28|    002|               12|                5|
| 2017-03-29|    002|                8|                3|
| 2017-03-30|    002|               10|                1|
| 2017-03-31|    002|               15|                4|
| 2017-03-24|    003|               18|                7|
| 2017-03-26|    003|               22|               11|
| 2017-03-27|    003|               27|               12|
| 2017-03-28|    003|               29|               15|
| 2017-04-01|    003|               31|               16|
| 2017-04-04|    003|               34|               22|
+-----------+-------+-----------------+-----------------+

To recreate my data, you can use this code:

data = {'DATE_TICKET': ['2017-03-24','2017-03-25','2017-03-26','2017-03-27','2017-03-28','2017-03-29','2017-03-30',
                        '2017-03-24','2017-03-25','2017-03-27','2017-03-28','2017-03-29','2017-03-30','2017-03-31',
                        '2017-03-24','2017-03-26','2017-03-27','2017-03-28','2017-04-01','2017-04-04'],
        'ID_SITE': ['001','001','001','001','001','001','001','002','002','002','002','002','002','002','003','003','003','003','003','003'],
        'MAX_TEMPERATURE_C': [22,25,31,29,30,25,24,18,17,15,12,8,10,15,18,22,27,29,31,34],
        'MIN_TEMPERATURE_C': [10,15,19,18,16,17,16,12,11,7,5,3,1,4,7,11,12,15,16,22]}
df = pd.DataFrame(data)
ddf = ctx.createDataFrame(df)
ddf = ddf.withColumn('DATE_TICKET', ddf['DATE_TICKET'].cast('date'))
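
If you want to skip the pandas round trip entirely, a minimal sketch that builds the same test DataFrame directly in Spark would look like this (assuming a SparkSession named spark; the original ctx is presumably a SQLContext or SparkSession):

# build rows as tuples from the dict above, then let Spark name the columns
rows = list(zip(data['DATE_TICKET'], data['ID_SITE'],
                data['MAX_TEMPERATURE_C'], data['MIN_TEMPERATURE_C']))
ddf = spark.createDataFrame(rows, ['DATE_TICKET', 'ID_SITE', 'MAX_TEMPERATURE_C', 'MIN_TEMPERATURE_C'])
ddf = ddf.withColumn('DATE_TICKET', ddf['DATE_TICKET'].cast('date'))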

At the moment my code looks like this:

import numpy as np
import pandas as pd
import pyspark
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.window import Window
from pyspark.sql.types import FloatType

window_size = 2
target =  int((window_size)-0.5)

w = Window.partitionBy("ID_SITE").orderBy("DATE_TICKET").rowsBetween(-(window_size), window_size)

median_udf = F.udf(lambda x: float(np.median(x)), FloatType())

rank_udf = F.udf(lambda x: pd.cut(x, 101, include_lowest=True, labels=list(range(0,101)))[target])

ddf.withColumn("list", F.collect_list("MAX_TEMPERATURE_C").over(w)) \
  .withColumn("rolling_median", median_udf("list")).show(truncate = False)

This works with the 'median_udf' function (which I pasted from another Stack Overflow post, by the way), but that function doesn't do what I need.

I want to use the rank_udf function, which works fine when I apply it to a single list. It ranks all the values of a given window and returns a single value, the one in the middle, which is the one I'm interested in.

For instance:

data = [22,25,31,29,31,34,26,21]
target =  int((len(data)/2)-0.5)
pd.cut(data, 101, include_lowest=True, labels=list(range(0,101)))[target]

But:

  • First of all, it returns an error when I use it as a UDF in PySpark.
  • Even if there were no error, it relies on a pandas function, and I want to do this without the pandas library because I'm working on hundreds of millions of rows and I need performance.

I tried using functions like Bucketizer or QuantileDiscretizer from pyspark.ml.feature, but I can't manage to make them work.
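
For instance, a minimal QuantileDiscretizer sketch like the one below does produce buckets (column names are just my test columns), but as far as I can tell it fits the quantiles over the whole column rather than over each rolling window, so it doesn't give me the per-window ranking I'm after:

from pyspark.ml.feature import QuantileDiscretizer

# buckets are computed from the entire MAX_TEMPERATURE_C column, not per window
discretizer = QuantileDiscretizer(numBuckets=101,
                                  inputCol='MAX_TEMPERATURE_C',
                                  outputCol='MAX_TEMPERATURE_C_BUCKET')
ddf_bucketed = discretizer.fit(ddf).transform(ddf)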

(P.S.: Yes, I know it's not really a percentile, since I'm using 101 bins instead of 100.)

(P.P.S.: I will edit this post if you need more context/information)

1 Answer

A couple of years late, but here's a solution without pandas or numpy UDFs.

This is a great brain teaser. I first wanted to avoid the "list" column and do everything inside the window frame, but I couldn't: after ordering by date, a second ordering would be required to calculate percentiles. So I kept the "list" column and used several array functions:

  • array_max (Spark 2.4)
  • array_min (Spark 2.4)
  • element_at (Spark 2.4)
  • size (Spark 1.5)
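
To make the formula concrete, here is the arithmetic for the row 2017-03-27 / ID_SITE 001 (visible in the output below), written as plain Python:

import math

lst = [25, 31, 29, 30, 25]                           # window collected for 2017-03-27, ID_SITE 001
target = (len(lst) + 1) // 2                         # 3: 1-based position of the middle element
mid, lo, hi = lst[target - 1], min(lst), max(lst)    # 29, 25, 31
print(math.ceil(100 / (hi - lo) * (mid - lo)))       # 67, matching percentile_of_target below

The full Spark version of the same computation over the rolling window is: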

# spread between the hottest and coldest value in each window
min_max_diff = F.array_max("list") - F.array_min("list")
# 1-based position of the middle element of the collected list
target = F.floor((F.size("list") + 1) / 2).cast('int')
ddf = ddf.withColumn("list", F.collect_list("MAX_TEMPERATURE_C").over(w)) \
         .withColumn(
             "percentile_of_target",
             F.ceil(100 / min_max_diff * (F.element_at("list", target) - F.array_min("list")))
         )
ddf.show()
#+-----------+-------+-----------------+-----------------+--------------------+--------------------+
#|DATE_TICKET|ID_SITE|MAX_TEMPERATURE_C|MIN_TEMPERATURE_C|                list|percentile_of_target|
#+-----------+-------+-----------------+-----------------+--------------------+--------------------+
#| 2017-03-24|    001|               22|               10|        [22, 25, 31]|                  34|
#| 2017-03-25|    001|               25|               15|    [22, 25, 31, 29]|                  34|
#| 2017-03-26|    001|               31|               19|[22, 25, 31, 29, 30]|                 100|
#| 2017-03-27|    001|               29|               18|[25, 31, 29, 30, 25]|                  67|
#| 2017-03-28|    001|               30|               16|[31, 29, 30, 25, 24]|                  86|
#| 2017-03-29|    001|               25|               17|    [29, 30, 25, 24]|                 100|
#| 2017-03-30|    001|               24|               16|        [30, 25, 24]|                  17|
#| 2017-03-24|    002|               18|               12|        [18, 17, 15]|                  67|
#| 2017-03-25|    002|               17|               11|    [18, 17, 15, 12]|                  84|
#| 2017-03-27|    002|               15|                7| [18, 17, 15, 12, 8]|                  70|
#| 2017-03-28|    002|               12|                5| [17, 15, 12, 8, 10]|                  45|
#| 2017-03-29|    002|                8|                3| [15, 12, 8, 10, 15]|                   0|
#| 2017-03-30|    002|               10|                1|     [12, 8, 10, 15]|                   0|
#| 2017-03-31|    002|               15|                4|         [8, 10, 15]|                  29|
#| 2017-03-24|    003|               18|                7|        [18, 22, 27]|                  45|
#| 2017-03-26|    003|               22|               11|    [18, 22, 27, 29]|                  37|
#| 2017-03-27|    003|               27|               12|[18, 22, 27, 29, 31]|                  70|
#| 2017-03-28|    003|               29|               15|[22, 27, 29, 31, 34]|                  59|
#| 2017-04-01|    003|               31|               16|    [27, 29, 31, 34]|                  29|
#| 2017-04-04|    003|               34|               22|        [29, 31, 34]|                  40|
#+-----------+-------+-----------------+-----------------+--------------------+--------------------+
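
Since the question mentions a dozen temperature columns, the same logic can be wrapped in a loop. The sketch below is untested beyond the sample data, the new column names are illustrative, and it adds a guard for the (hypothetical) case where every value in a window is identical, which would otherwise divide by zero and yield null:

temperature_cols = ["MAX_TEMPERATURE_C", "MIN_TEMPERATURE_C"]  # extend with the other columns

for c in temperature_cols:
    # collect the symmetrical window of values for this column
    ddf = ddf.withColumn("_list", F.collect_list(c).over(w))
    spread = F.array_max("_list") - F.array_min("_list")
    pos = F.floor((F.size("_list") + 1) / 2).cast("int")       # 1-based middle position
    ddf = ddf.withColumn(
        c + "_window_percentile",
        F.when(spread == 0, F.lit(0))                          # constant window: avoid division by zero
         .otherwise(F.ceil(100 / spread * (F.element_at("_list", pos) - F.array_min("_list"))))
    ).drop("_list")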