Spark - Calculating running sum with a threshold


I have a use case where I need to compute a running sum over a partition, such that the running sum never exceeds a certain threshold.

For example:

// Input dataset

| id | created_on  | value | running_sum  | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A  | 2021-01-01  | 1.0   | 0.0          | 10.0      |
| A  | 2021-01-02  | 2.0   | 0.0          | 10.0      |
| A  | 2021-01-03  | 8.0   | 0.0          | 10.0      |
| A  | 2021-01-04  | 5.0   | 0.0          | 10.0      |

// Output requirement

| id | created_on  | value | running_sum  | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A  | 2021-01-01  | 1.0   | 1.0          | 10.0      |
| A  | 2021-01-02  | 2.0   | 3.0          | 10.0      |
| A  | 2021-01-03  | 8.0   | 3.0          | 10.0      |
| A  | 2021-01-04  | 5.0   | 8.0          | 10.0      |

Here, the threshold for any given id is the same for all rows with that id. Please note that the 3rd row was skipped from the sum because including it would have pushed running_sum over the threshold, but the 4th row was added since the running_sum stayed within the threshold.
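
For reference, the sample dataset above can be built roughly like this (a minimal sketch, assuming a SparkSession named spark and keeping created_on as a plain string for brevity):

import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Sample rows matching the input table above (running_sum starts at 0.0)
List<Row> rows = Arrays.asList(
        RowFactory.create("A", "2021-01-01", 1.0, 0.0, 10.0),
        RowFactory.create("A", "2021-01-02", 2.0, 0.0, 10.0),
        RowFactory.create("A", "2021-01-03", 8.0, 0.0, 10.0),
        RowFactory.create("A", "2021-01-04", 5.0, 0.0, 10.0));

StructType schema = new StructType()
        .add("id", DataTypes.StringType)
        .add("created_on", DataTypes.StringType)
        .add("value", DataTypes.DoubleType)
        .add("running_sum", DataTypes.DoubleType)
        .add("threshold", DataTypes.DoubleType);

Dataset<Row> dataset = spark.createDataFrame(rows, schema);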

I was able to calculate the running sum without considering the threshold, using window functions as follows:

final WindowSpec window = Window.partitionBy(col("id"))
                .orderBy(col("created_on").asc())
                .rowsBetween(Window.unboundedPreceding(), Window.currentRow());

dataset.withColumn("running_sum", sum(col("value")).over(window)).show();

// Output
| id | created_on  | value | running_sum  | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A  | 2021-01-01  | 1.0   | 1.0          | 10.0      |
| A  | 2021-01-02  | 2.0   | 3.0          | 10.0      |
| A  | 2021-01-03  | 8.0   | 11.0         | 10.0      |
| A  | 2021-01-04  | 5.0   | 16.0         | 10.0      |

I tried using when() with the window, and also tried lag(), but both gave me unexpected results.

// With just sum over window
final WindowSpec window = Window.partitionBy(col("id"))
                .orderBy(col("created_on").asc())
                .rowsBetween(Window.unboundedPreceding(), Window.currentRow());

dataset.withColumn("running_sum", 
            when(sum(col("value")).over(window).leq(col("threshold")), sum(col("value")).over(window))
                .otherwise(sum(col("value")).over(window).minus(col("value")))
        ).show();

// Output
| id | created_on  | value | running_sum  | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A  | 2021-01-01  | 1.0   | 1.0          | 10.0      |
| A  | 2021-01-02  | 2.0   | 3.0          | 10.0      |
| A  | 2021-01-03  | 8.0   | 3.0          | 10.0      |
| A  | 2021-01-04  | 5.0   | 11.0         | 10.0      |


// With combination of sum and lag
final WindowSpec lagWindow = Window.partitionBy(col("id")).orderBy(col("created_on").asc());

final WindowSpec window = Window.partitionBy(col("id"))
                .orderBy(col("created_on").asc())
                .rowsBetween(Window.unboundedPreceding(), Window.currentRow());

dataset.withColumn("running_sum", 
            when(sum(col("value")).over(window).leq(col("threshold")), sum(col("value")).over(window))
                .otherwise(lag(col("running_sum"), 1, 0).over(lagWindow))
        ).show();

// Output
| id | created_on  | value | running_sum  | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A  | 2021-01-01  | 1.0   | 1.0          | 10.0      |
| A  | 2021-01-02  | 2.0   | 3.0          | 10.0      |
| A  | 2021-01-03  | 8.0   | 0.0          | 10.0      |
| A  | 2021-01-04  | 5.0   | 0.0          | 10.0      |

After going through some resources on the web, I came across User Defined Aggregate Functions (UDAFs), which I believe should solve my problem.
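
For reference, here is a rough, untested sketch of what I think such an Aggregator (Spark 3.x) might look like; the threshold is passed as a constructor argument for simplicity, and the merge step is awkward because this conditional sum is order-dependent and not associative:

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

// Rough sketch only: adds a value to the running sum only while the sum stays within the threshold
public class ThresholdSumAggregator extends Aggregator<Double, Double, Double> {
    private final double threshold;

    public ThresholdSumAggregator(double threshold) {
        this.threshold = threshold;
    }

    @Override
    public Double zero() {
        return 0.0;
    }

    @Override
    public Double reduce(Double acc, Double value) {
        // Skip the value if adding it would exceed the threshold
        return (acc + value > threshold) ? acc : acc + value;
    }

    @Override
    public Double merge(Double a, Double b) {
        // Awkward: this conditional sum is not associative, so merging
        // partial buffers is not really well-defined for this problem
        return (a + b > threshold) ? a : a + b;
    }

    @Override
    public Double finish(Double acc) {
        return acc;
    }

    @Override
    public Encoder<Double> bufferEncoder() {
        return Encoders.DOUBLE();
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

It could presumably be registered with functions.udaf(new ThresholdSumAggregator(10.0), Encoders.DOUBLE()) and applied over the same window, though I have not verified this.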

But I prefer to implement it without using UDAFs. Please let me know if there is any other way to do this or if I'm missing something in the code that I have tried.

Thanks!

1 Answer

Collect all values for an id in an array and then use aggregate to sum conditionally over the array:

import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.expressions.Window;

df = ...
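
// 1. Collect the values seen so far per id (ordered by created_on) into an array on each row
// 2. Fold over that array, adding a value only when the sum stays within the threshold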
df.withColumn("running_sum", collect_list("value")
                .over(Window.partitionBy("id").orderBy("created_on")))
  .withColumn("running_sum",
          expr("aggregate(running_sum, double(0), (acc,x) -> if(acc + x > threshold, acc, acc +x ))"))
  .show();

Output:

+---+-------------------+-----+-----------+---------+
| id|         created_on|value|running_sum|threshold|
+---+-------------------+-----+-----------+---------+
|  A|2021-01-01 00:00:00|  1.0|        1.0|     10.0|
|  A|2021-01-02 00:00:00|  2.0|        3.0|     10.0|
|  A|2021-01-03 00:00:00|  8.0|        3.0|     10.0|
|  A|2021-01-04 00:00:00|  5.0|        8.0|     10.0|
+---+-------------------+-----+-----------+---------+
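
To see what the aggregate fold does, take the last row: its collected array is [1.0, 2.0, 8.0, 5.0], and the lambda applies the same logic as this plain-Java sketch:

// Plain-Java equivalent of the fold performed by aggregate(...) for the last row
double[] values = {1.0, 2.0, 8.0, 5.0};
double threshold = 10.0;
double acc = 0.0;
for (double x : values) {
    if (acc + x <= threshold) {   // add the value only if the sum stays within the threshold
        acc += x;
    }                             // otherwise skip it, as happens with 8.0 here
}
// acc == 8.0, matching running_sum for the last row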