How do I compute an average over months with a window function when there are null values in between?


I have a dataframe like below

from pyspark.sql import Window
from pyspark.sql.functions import avg

df = spark.createDataFrame(
  [(1,1,10), (2,1,10), (3,1,None), (4,1,10), (5,1,10), (6,1,20),
   (7,1,20), (1,2,10), (2,2,10), (3,2,10), (4,2,20), (5,2,20)],
  ["Month","customer","amount"])

# 6-month window: the current month and the 5 preceding months, per customer
windowPartition = Window.partitionBy("customer").orderBy("Month").rangeBetween(Window.currentRow-5, Window.currentRow)
df = df.withColumn("avg_6_month", avg('amount').over(windowPartition))
display(df.orderBy("customer","Month"))

[screenshot: output of the code above]

I want to compute an average over 6 months of data only when there are no nulls in between. I was able to achieve the results below using a window function, but it ignores nulls: for customer 1 the average is calculated even though there is a null value, and for customer 2 there are only 5 months of data yet it still tries to calculate an average.

[screenshot: rolling averages with nulls ignored]

Since I want to calculate the average only when there are 6 continuous months without nulls, I created a count variable and calculated the average only when that count is greater than or equal to 6. The result is:

from pyspark.sql.functions import broadcast, when

# Count of non-null amounts per customer
df2 = df.groupBy("customer").agg({"amount":"count"}).withColumnRenamed("count(amount)", "Amount_count")
df = df.join(broadcast(df2), on='customer', how='left')

windowPartition = Window.partitionBy("customer").orderBy("Month").rangeBetween(Window.currentRow-5, Window.currentRow)
df = df.withColumn("avg_6_month", when(df.Amount_count >= 6, avg('amount').over(windowPartition)).otherwise(None))

columns = ['Month', 'customer', 'amount', 'avg_6_month']
display(df.select(*columns).orderBy("customer","Month"))

[screenshot: averages computed only when Amount_count >= 6]

So now the average is not calculated for customer 2, which is what I wanted. But for customer 1 I still don't want the average to be calculated, because there are no 6 continuous months without a null in the amount column.

I am new to PySpark, but I know how to achieve this in R:

> amount <- c(10,10,NA,10,10,20,20,10)
> roll_mean(amount,  n = 3, align ="right", fill = NA)
[1]       NA       NA       NA       NA       NA 13.33333 16.66667 16.66667
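
The same behaviour can be sketched with pandas for comparison (assuming pandas is available): Series.rolling defaults min_periods to the window size, so any window that is incomplete or contains a NaN yields NaN, just like roll_mean with fill = NA.

import numpy as np
import pandas as pd

amount = pd.Series([10, 10, np.nan, 10, 10, 20, 20, 10])
# Windows that are incomplete or contain a NaN return NaN, matching
# roll_mean(amount, n = 3, align = "right", fill = NA) in R:
# NaN NaN NaN NaN NaN 13.333333 16.666667 16.666667
print(amount.rolling(window=3).mean().tolist())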

I am expecting an outcome like the one below in PySpark.

My actual data has many nulls across different months for many customers, so I want to calculate the average only when there are no nulls in 6 continuous months. Is this possible using a window function, or is there another way to achieve this result?

[screenshot: expected outcome]


There are 2 answers below.

Accepted answer

Definitely possible! In your attempt, Amount_count is calculated per customer, but you'll want to do this per six-month window. Something like this should work:

from pyspark.sql import Window
from pyspark.sql.functions import avg, col, lit, max, sum, when

start = 5
calculate_between = 6

df = spark.createDataFrame(
    [(1,1,10), (2,1,10), (3,1,None), (4,1,10), (5,1,10), (6,1,20),
     (7,1,20), (1,2,10), (2,2,10), (3,2,10), (4,2,20), (5,2,20)],
    ["month","customer","amount"])

windowPartition = Window.partitionBy("customer").orderBy("month").rangeBetween(Window.currentRow-start, Window.currentRow)

# True if any amount in the window is null
df = df.withColumn("six_month_has_null", max(col("amount").isNull()).over(windowPartition))
df = df.withColumn("avg_6_month", avg('amount').over(windowPartition))
df = df.withColumn("avg_6_month", when(df.six_month_has_null, None).otherwise(col("avg_6_month")))
# True only if the window contains the full 6 rows
df = df.withColumn("window_has_six_points", sum(lit(1)).over(windowPartition) == calculate_between)
df = df.withColumn("avg_6_month1", when(df.six_month_has_null | ~df.window_has_six_points, None).otherwise(col("avg_6_month")))
display(df)

Result

[screenshot of the result]

For the next data frame, the same code gives:

df = spark.createDataFrame(
    [(1,1,1), (2,1,2), (3,1,None), (4,1,4), (5,1,5), (6,1,6),
     (7,1,7), (8,1,8), (9,1,9), (10,1,None), (11,1,11), (12,1,12),
     (1,2,10), (2,2,10), (3,2,10), (4,2,20), (5,2,17)],
    ["month","customer","amount"])

[screenshot of the result]
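
A more compact variant along the same lines (an untested sketch): count("amount") over the window counts only the non-null rows, while count(lit(1)) counts all rows, so requiring both to equal calculate_between rules out nulls and short windows in a single expression.

from pyspark.sql.functions import avg, count, lit, when

# Average only when the window holds exactly 6 rows and all 6 amounts are non-null
df = df.withColumn(
    "avg_6_month_compact",
    when(
        (count("amount").over(windowPartition) == calculate_between)
        & (count(lit(1)).over(windowPartition) == calculate_between),
        avg("amount").over(windowPartition),
    ).otherwise(None),
)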

Another answer

I managed to solve it myself, but it is a lengthy solution and not good coding practice for production code. If anyone has a better way to solve this, I would appreciate their idea and would consider it the answer to this question.

from pyspark.sql import Window
from pyspark.sql.functions import avg, sum, when

df = spark.createDataFrame(
    [(1,1,1), (2,1,2), (3,1,None), (4,1,4), (5,1,5), (6,1,6), (7,1,7),
     (1,2,10), (2,2,10), (3,2,10), (4,2,20), (5,2,17)],
    ["month","customer","amount"])

start             = 5
calculate_between = 6

# Create a count variable holding the number of non-null values per customer
df2 = df.groupBy("customer").agg({"amount":"count"}).withColumnRenamed("count(amount)", "Amount_count")
df = df.join(df2, on='customer', how='left')
windowPartition = Window.partitionBy("customer").orderBy("month").rangeBetween(Window.currentRow-start, Window.currentRow)

# Since I am interested in a 6-month average, the first 5 months should be NULL,
# so I force the first five months' average to NULL and only compute an average
# when the count column shows at least 6 non-NULL values. I then replace NULL
# with a dummy value 9876543210 and take a sum over the 6-month window. Any sum
# over 6 rows that includes a (NULL replaced by) 9876543210 can only be
# 9876543210 (if the entries are like [0,9876543210,0,0,0,0]) or greater.
# If that is the case the average should be NULL; otherwise I calculate the
# average over 6 months, which means there is no NULL in 6 consecutive months.

df = df.withColumn("avg_6_month", when((df.Amount_count >= calculate_between) & (df.month > start), avg('amount').over(windowPartition)).otherwise(None))
df = df.na.fill(value=9876543210, subset=['amount'])
df = df.withColumn("sum_6_month2", when((df.Amount_count >= calculate_between) & (df.month > start), sum('amount').over(windowPartition)).otherwise(None))
df = df.withColumn("avg_6_month_final", when(df.sum_6_month2 >= 9876543210, None).otherwise(df.avg_6_month))
df = df.withColumn('amount', when(df.amount == 9876543210, None).otherwise(df.amount))

columns = ['customer', 'month', 'amount', 'Amount_count', 'avg_6_month_final']
display(df.select(*columns).orderBy("customer","month"))

Result

[screenshot of the result]

This also works for a data frame with 12 months and more than one NULL value:

df = spark.createDataFrame(
    [(1,1,1), (2,1,2), (3,1,None), (4,1,4), (5,1,5), (6,1,6),
     (7,1,7), (8,1,8), (9,1,9), (10,1,None), (11,1,11), (12,1,12),
     (1,2,10), (2,2,10), (3,2,10), (4,2,20), (5,2,17)],
    ["month","customer","amount"])

[screenshot of the result]

With start = 1 and calculate_between = 2 the result is:

[screenshot of the result]
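
As a rough sanity check (assuming pandas is available, the data fits in memory, and the months are consecutive so that a row-based window matches the month-range window), the expected values can also be reproduced with pandas, whose rolling mean is NaN whenever the window is incomplete or contains a missing value:

import pandas as pd

# Cross-check: rolling(calculate_between).mean() per customer should match
# avg_6_month_final, since pandas returns NaN for incomplete windows and for
# windows that contain a missing value.
pdf = df.select("customer", "month", "amount").toPandas().sort_values(["customer", "month"])
pdf["amount"] = pdf["amount"].astype("float64")   # make sure nulls become NaN
pdf["avg_check"] = (
    pdf.groupby("customer")["amount"]
       .transform(lambda s: s.rolling(window=calculate_between).mean())
)
print(pdf)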