How to run a user-defined function over a window in a Spark DataFrame?


I am trying to detect the outliers from my spark dataframe. Below is the data sample.

pressure Timestamp
358.64 2022-01-01 00:00:00
354.98 2022-01-01 00:10:00
350.34 2022-01-01 00:20:00
429.69 2022-01-01 00:30:00
420.41 2022-01-01 00:40:00
413.82 2022-01-01 00:50:00
409.42 2022-01-01 01:00:00
409.67 2022-01-01 01:10:00
413.33 2022-01-01 01:20:00
405.03 2022-01-01 01:30:00
1209.42 2022-01-01 01:40:00
405.03 2022-01-01 01:50:00
404.54 2022-01-01 02:00:00
405.27 2022-01-01 02:10:00
999.45 2022-01-01 02:20:00
362.79 2022-01-01 02:30:00
349.37 2022-01-01 02:40:00
356.2 2022-01-01 02:50:00
3200.23 2022-01-01 03:00:00
348.39 2022-01-01 03:10:00

Here is my function to find the outliers in the entire dataset:

from pyspark.sql import functions as F

def outlierDetection(df):
    # exact 20th and 80th percentiles of pressure (relativeError=0 disables approximation)
    inter_quantile_range = df.approxQuantile("pressure", [0.20, 0.80], relativeError=0)

    Q1 = inter_quantile_range[0]
    Q3 = inter_quantile_range[1]

    inter_quantile_diff = Q3 - Q1

    minimum_Q1 = Q1 - 1.5 * inter_quantile_diff
    maximum_Q3 = Q3 + 1.5 * inter_quantile_diff

    # flag every row outside [minimum_Q1, maximum_Q3]
    df = df.withColumn("isOutlier", F.when((df["pressure"] > maximum_Q3) | (df["pressure"] < minimum_Q1), 1).otherwise(0))
    return df

It works as expected, but it computes a single pair of bounds for the whole dataset and flags every value outside that global range.

Instead, I want to detect outliers within each hourly interval.
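As a plain-Python sketch of that goal (no Spark involved), the snippet below groups the sample readings by the first 13 characters of the timestamp (date and hour) and applies the same 0.2/0.8 quantile rule with a 1.5 multiplier inside each group. The helper names quantile and hourly_outliers are hypothetical, and the quantile uses linear interpolation (numpy.quantile's default method), which is an assumption and may not match Spark's approxQuantile exactly on other data.

```python
from collections import defaultdict

# Sample rows from the question: (pressure, timestamp string)
SAMPLE = [
    (358.64, "2022-01-01 00:00:00"), (354.98, "2022-01-01 00:10:00"),
    (350.34, "2022-01-01 00:20:00"), (429.69, "2022-01-01 00:30:00"),
    (420.41, "2022-01-01 00:40:00"), (413.82, "2022-01-01 00:50:00"),
    (409.42, "2022-01-01 01:00:00"), (409.67, "2022-01-01 01:10:00"),
    (413.33, "2022-01-01 01:20:00"), (405.03, "2022-01-01 01:30:00"),
    (1209.42, "2022-01-01 01:40:00"), (405.03, "2022-01-01 01:50:00"),
    (404.54, "2022-01-01 02:00:00"), (405.27, "2022-01-01 02:10:00"),
    (999.45, "2022-01-01 02:20:00"), (362.79, "2022-01-01 02:30:00"),
    (349.37, "2022-01-01 02:40:00"), (356.2, "2022-01-01 02:50:00"),
    (3200.23, "2022-01-01 03:00:00"), (348.39, "2022-01-01 03:10:00"),
]


def quantile(sorted_vals, p):
    """Linear-interpolation quantile of an already sorted list."""
    h = (len(sorted_vals) - 1) * p
    lo = int(h)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (h - lo) * (sorted_vals[hi] - sorted_vals[lo])


def hourly_outliers(rows, q_low=0.2, q_high=0.8, k=1.5):
    """Return the (pressure, ts) rows that fall outside their own hour's bounds."""
    groups = defaultdict(list)
    for pressure, ts in rows:
        groups[ts[:13]].append((pressure, ts))  # key like '2022-01-01 01'
    flagged = []
    for _, members in sorted(groups.items()):
        vals = sorted(p for p, _ in members)
        q1, q3 = quantile(vals, q_low), quantile(vals, q_high)
        lower, upper = q1 - k * (q3 - q1), q3 + k * (q3 - q1)
        flagged += [(p, ts) for p, ts in members if p < lower or p > upper]
    return flagged


for pressure, ts in hourly_outliers(SAMPLE):
    print(ts, pressure)
```

In this sketch, hour 03 contains only two readings, so its bounds are wide enough that 3200.23 is not flagged; very small groups make a per-group rule like this one weak.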

I have created another column that holds the date and hour of each reading:

pressure Timestamp date_hour
358.64 2022-01-01 00:00:00 2022-01-01 00
354.98 2022-01-01 00:10:00 2022-01-01 00
350.34 2022-01-01 00:20:00 2022-01-01 00
429.69 2022-01-01 00:30:00 2022-01-01 00
420.41 2022-01-01 00:40:00 2022-01-01 00
413.82 2022-01-01 00:50:00 2022-01-01 00
409.42 2022-01-01 01:00:00 2022-01-01 01
409.67 2022-01-01 01:10:00 2022-01-01 01
413.33 2022-01-01 01:20:00 2022-01-01 01
405.03 2022-01-01 01:30:00 2022-01-01 01

I am trying to create a window like this:

from pyspark.sql import Window

w1 = Window.partitionBy("date_hour").orderBy("Timestamp")

Is there any way to apply my function over each window in the DataFrame?
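One thing to watch with w1: in Spark, when a window specification has an ordering but no explicit frame, the default frame runs from the start of the partition to the current row, while a window without ordering covers the whole partition. Any aggregate evaluated over w1 therefore sees a growing prefix of the hour rather than the full hour. The plain-Python sketch below illustrates the difference for the 80th percentile of hour 00; it reuses the linear-interpolation quantile as an assumption and is not Spark's own implementation.

```python
def quantile(sorted_vals, p):
    """Linear-interpolation quantile of an already sorted list."""
    h = (len(sorted_vals) - 1) * p
    lo = int(h)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (h - lo) * (sorted_vals[hi] - sorted_vals[lo])


# Hour 00 readings, in timestamp order
hour_00 = [358.64, 354.98, 350.34, 429.69, 420.41, 413.82]

# Whole-partition frame (no orderBy): every row sees all six readings
whole = [quantile(sorted(hour_00), 0.8)] * len(hour_00)

# Running frame (the default once orderBy is set): row i sees readings 0..i only
running = [quantile(sorted(hour_00[: i + 1]), 0.8) for i in range(len(hour_00))]

for w, r in zip(whole, running):
    print(w, r)
```

Only the last row of the running frame agrees with the whole-hour value, so per-hour bounds need a window partitioned by date_hour without orderBy (or an explicit unbounded frame).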

Best answer:

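If you're using Spark 3.1+, you can use percentile_approx to calculate the quantiles and do the rest of the calculation with built-in PySpark functions. If your Spark version doesn't have that function, you can use a UDF that calls numpy.quantile instead. Both options are shown in the code. In either case, rather than applying the function over a window, compute the bounds once per hour with groupBy and join them back to the original rows.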

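from pyspark.sql import functions as func
from pyspark.sql.types import ArrayType, FloatType

# data_ls: list of (pressure, timestamp string) tuples from the sample above
data_sdf = spark.sparkContext.parallelize(data_ls).toDF(['pressure', 'ts']). \
    withColumn('ts', func.col('ts').cast('timestamp')). \
    withColumn('dt_hr', func.date_format('ts', 'yyyyMMddHH'))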

# +--------+-------------------+----------+
# |pressure|                 ts|     dt_hr|
# +--------+-------------------+----------+
# |  358.64|2022-01-01 00:00:00|2022010100|
# |  354.98|2022-01-01 00:10:00|2022010100|
# |  350.34|2022-01-01 00:20:00|2022010100|
# |  429.69|2022-01-01 00:30:00|2022010100|
# |  420.41|2022-01-01 00:40:00|2022010100|
# |  413.82|2022-01-01 00:50:00|2022010100|
# |  409.42|2022-01-01 01:00:00|2022010101|
# |  409.67|2022-01-01 01:10:00|2022010101|
# |  413.33|2022-01-01 01:20:00|2022010101|
# |  405.03|2022-01-01 01:30:00|2022010101|
# | 1209.42|2022-01-01 01:40:00|2022010101|
# |  405.03|2022-01-01 01:50:00|2022010101|
# +--------+-------------------+----------+

Getting the quantiles (both methods are shown; use whichever your Spark version supports):

# spark 3.1+ has percentile_approx
pressure_quantile_sdf = data_sdf. \
    groupBy('dt_hr'). \
    agg(func.percentile_approx('pressure', [0.2, 0.8]).alias('quantile_20_80'))

# +----------+----------------+
# |     dt_hr|  quantile_20_80|
# +----------+----------------+
# |2022010100|[354.98, 420.41]|
# |2022010101|[405.03, 413.33]|
# +----------+----------------+

# older Spark versions: collect each hour's values and compute the quantiles in a numpy UDF
def numpy_quantile_20_80(list_col):
    import numpy as np

    q_20 = np.quantile(list_col, 0.2)
    q_80 = np.quantile(list_col, 0.8)

    return [float(q_20), float(q_80)]

numpy_quantile_20_80_udf = func.udf(numpy_quantile_20_80, ArrayType(FloatType()))

pressure_quantile_sdf = data_sdf. \
    groupBy('dt_hr'). \
    agg(func.collect_list('pressure').alias('pressure_list')). \
    withColumn('quantile_20_80', numpy_quantile_20_80_udf(func.col('pressure_list')))

# +----------+--------------------+----------------+
# |     dt_hr|       pressure_list|  quantile_20_80|
# +----------+--------------------+----------------+
# |2022010100|[358.64, 354.98, ...|[354.98, 420.41]|
# |2022010101|[409.42, 409.67, ...|[405.03, 413.33]|
# +----------+--------------------+----------------+
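Both methods produced the same quantile pairs on this sample. A plain-Python sketch of numpy.quantile's default linear interpolation shows why: with six readings per hour, the fractional index (n - 1) x p is exactly 1 for p = 0.2 and 4 for p = 0.8, so no interpolation happens and the results are observed values, which is also what percentile_approx returns. With other group sizes the two can differ, because percentile_approx always returns a value from the column while numpy interpolates between neighbours. The helper linear_quantile is hypothetical.

```python
def linear_quantile(sorted_vals, p):
    """Quantile using numpy.quantile's default 'linear' rule: index = (n - 1) * p."""
    h = (len(sorted_vals) - 1) * p
    lo = int(h)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (h - lo) * (sorted_vals[hi] - sorted_vals[lo])


hours = {
    "2022010100": [358.64, 354.98, 350.34, 429.69, 420.41, 413.82],
    "2022010101": [409.42, 409.67, 413.33, 405.03, 1209.42, 405.03],
}

for dt_hr, vals in hours.items():
    s = sorted(vals)
    for p in (0.2, 0.8):
        # with n = 6 the index (n - 1) * p is a whole number, so no interpolation
        print(dt_hr, p, (len(s) - 1) * p, linear_quantile(s, p))

# With five readings the index is fractional and the result lies between two values
print(linear_quantile([1.0, 2.0, 3.0, 4.0, 5.0], 0.2))
```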

With the quantiles in hand, computing the per-hour outlier bounds is straightforward:

pressure_quantile_sdf = pressure_quantile_sdf. \
    withColumn('quantile_20', func.col('quantile_20_80')[0]). \
    withColumn('quantile_80', func.col('quantile_20_80')[1]). \
    withColumn('min_q_20', func.col('quantile_20') - 1.5 * (func.col('quantile_80') - func.col('quantile_20'))). \
    withColumn('max_q_80', func.col('quantile_80') + 1.5 * (func.col('quantile_80') - func.col('quantile_20'))). \
    select('dt_hr', 'min_q_20', 'max_q_80')

# +----------+------------------+------------------+
# |     dt_hr|          min_q_20|          max_q_80|
# +----------+------------------+------------------+
# |2022010100|256.83502197265625| 518.5549926757812|
# |2022010101|392.58001708984375|425.77996826171875|
# +----------+------------------+------------------+
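The long trailing digits in these bounds (for example 392.58001708984375 rather than 392.58) are consistent with the UDF's ArrayType(FloatType()): the quantiles are stored as 32-bit floats, and the later arithmetic runs in double precision on those rounded values. The plain-Python sketch below reproduces the printed bounds by rounding each quantile to 32 bits with struct; the helper names to_float32 and bounds are hypothetical. Declaring the UDF with DoubleType() instead should avoid this rounding.

```python
import struct


def to_float32(x):
    """Round a Python float to the nearest 32-bit float, as FloatType storage does."""
    return struct.unpack("f", struct.pack("f", x))[0]


def bounds(q20, q80, k=1.5):
    """Lower and upper outlier bounds from the 20th/80th percentiles."""
    spread = q80 - q20
    return q20 - k * spread, q80 + k * spread


# Quantile pairs from the output above
pairs = {"2022010100": (354.98, 420.41), "2022010101": (405.03, 413.33)}

for dt_hr, (q20, q80) in pairs.items():
    print(dt_hr, "double:", bounds(q20, q80))
    print(dt_hr, "float32 quantiles:", bounds(to_float32(q20), to_float32(q80)))
```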

# flag outliers by joining the per-hour bounds back to every row; select only the columns you need
data_sdf. \
    join(pressure_quantile_sdf, 'dt_hr', 'left'). \
    withColumn('is_outlier', ((func.col('pressure') > func.col('max_q_80')) | (func.col('pressure') < func.col('min_q_20'))).cast('int')). \
    show()

# +----------+--------+-------------------+------------------+------------------+----------+
# |     dt_hr|pressure|                 ts|          min_q_20|          max_q_80|is_outlier|
# +----------+--------+-------------------+------------------+------------------+----------+
# |2022010100|  358.64|2022-01-01 00:00:00|256.83502197265625| 518.5549926757812|         0|
# |2022010100|  354.98|2022-01-01 00:10:00|256.83502197265625| 518.5549926757812|         0|
# |2022010100|  350.34|2022-01-01 00:20:00|256.83502197265625| 518.5549926757812|         0|
# |2022010100|  429.69|2022-01-01 00:30:00|256.83502197265625| 518.5549926757812|         0|
# |2022010100|  420.41|2022-01-01 00:40:00|256.83502197265625| 518.5549926757812|         0|
# |2022010100|  413.82|2022-01-01 00:50:00|256.83502197265625| 518.5549926757812|         0|
# |2022010101|  409.42|2022-01-01 01:00:00|392.58001708984375|425.77996826171875|         0|
# |2022010101|  409.67|2022-01-01 01:10:00|392.58001708984375|425.77996826171875|         0|
# |2022010101|  413.33|2022-01-01 01:20:00|392.58001708984375|425.77996826171875|         0|
# |2022010101|  405.03|2022-01-01 01:30:00|392.58001708984375|425.77996826171875|         0|
# |2022010101| 1209.42|2022-01-01 01:40:00|392.58001708984375|425.77996826171875|         1|
# |2022010101|  405.03|2022-01-01 01:50:00|392.58001708984375|425.77996826171875|         0|
# +----------+--------+-------------------+------------------+------------------+----------+