Spark custom Aggregator with multiple columns


I have written a Spark UDAF that takes two columns as input (timestamp and value) and calculates a rate of change via least squares over all data points in a given window. It works perfectly fine; the code is below, shortened to the relevant pieces.

UDAF

import com.google.common.collect.ImmutableList;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MinimalUDAF extends UserDefinedAggregateFunction {
    @Override
    public StructType inputSchema() {
        return DataTypes.createStructType(ImmutableList.of(
                DataTypes.createStructField("timestamp", DataTypes.LongType, true),
                DataTypes.createStructField("value", DataTypes.DoubleType, true)));
    }

    @Override
    public StructType bufferSchema() {
        return DataTypes.createStructType(ImmutableList.of(
                DataTypes.createStructField("timestamps", DataTypes.createArrayType(DataTypes.LongType), false),
                DataTypes.createStructField("values", DataTypes.createArrayType(DataTypes.DoubleType), false)));
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }


    @Override
    public void update(final MutableAggregationBuffer buffer, final Row input) {
        final Long timestamp = input.getAs(0);
        final Double value = input.getAs(1);
        // ...
    }

    @Override
    public Object evaluate(final Row buffer) {
        // calculate and return rate of change
    }

    // ...
}

Usage

@Test
public void testMinimalUDAF() {
    final WindowSpec windowSpec = Window.partitionBy("instrument_id")
            .orderBy("timestamp")
            // look back 24h
            .rangeBetween(-1 * 60 * 60 * 24, Window.currentRow());
    final Dataset<Row> dataset = sparkSession
            .get()
            .createDataFrame(
                    ImmutableList.of(
                            RowFactory.create(1, 1, 1671079600L),
                            RowFactory.create(1, 101, 1671079650L),
                            RowFactory.create(1, 401, 1671079700L),
                            RowFactory.create(2, 50, 1671079630L),
                            RowFactory.create(2, 60, 1671079640L)),
                    StructType.fromDDL("instrument_id INT, value INT, timestamp LONG"))
            .withColumn(
                    "udaf",
                    new MinimalUDAF()
                            .apply(functions.col("timestamp"), functions.col("value"))
                            .over(windowSpec));
    dataset.show(false);
}

Result

+-------------+-----+----------+----+
|instrument_id|value|timestamp |udaf|
+-------------+-----+----------+----+
|1            |1    |1671079600|null|
|1            |101  |1671079650|2.0 |
|1            |401  |1671079700|4.0 |
|2            |50   |1671079630|null|
|2            |60   |1671079640|1.0 |
+-------------+-----+----------+----+

As of Spark 3, UserDefinedAggregateFunction is deprecated and the recommendation is to use the Aggregator interface (registered via functions.udaf) instead. I tried to use MapType / StructType as the IN class but ran into issues (see below). I have seen the Java example in the Spark code base / docs, but it only passes a single column to the aggregator, and I am struggling to adapt it to my case so that it takes two columns as arguments.
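
That single-column registration looks roughly like this (sketched from memory; MyAverage (an Aggregator<Long, Average, Double>), salary and employees are the names used in the docs example):

sparkSession
        .get()
        .udf()
        .register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));

sparkSession
        .get()
        .sql("SELECT myAverage(salary) AS average_salary FROM employees")
        .show();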

Aggregator

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.types.MapType;

public class MinimalMapAggregator extends Aggregator<MapType, Average, Double> {

    @Override
    public Average reduce(final Average buffer, final MapType data) {
        // ...
    }

    // ...
}

Usage

@Test
public void testMinimalMapAggregator() {
    sparkSession
            .get()
            .udf()
            .register("myUDAF", functions.udaf(new MinimalMapAggregator(), Encoders.bean(MapType.class)));

    final Dataset<Row> dataset = sparkSession
            .get()
            .createDataFrame(
                    ImmutableList.of(
                            RowFactory.create(1, 1, 1671079600L),
                            RowFactory.create(1, 101, 1671079650L),
                            RowFactory.create(1, 401, 1671079700L),
                            RowFactory.create(2, 50, 1671079630L),
                            RowFactory.create(2, 60, 1671079640L)),
                    StructType.fromDDL("instrument_id INT, value INT, timestamp LONG"))
            .withColumn(
                    "ts_val",
                    functions.map(
                            functions.lit("timestamp"),
                            functions.col("timestamp"),
                            functions.lit("value"),
                            functions.col("value")))
            .withColumn(
                    "udaf", functions.expr("myUDAF(ts_val) over (partition by instrument_id order by timestamp)"));
    dataset.show(false);
}

When running this, the data variable in reduce() is just an empty MapType, so I can't get to the column values. I am seeing the same behavior with StructType.

How do I do this properly? Thanks a lot for your help!

1 Answer

You can find an example of a Spark Aggregator that uses a Map as input in Apache DataFu-Spark (disclosure: I am a member of DataFu and wrote this code). Take a look at the MapSetMerge Aggregator.

The declaration looks like this:

/**
 * Performs a deep merge of maps of kind string -> set<string>
 */
class MapSetMerge extends Aggregator[Map[String, Array[String]], Map[String, scala.collection.immutable.Set[String]], Map[String, Array[String]]] with Serializable {
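
The relevant part for the question is the IN type: MapSetMerge takes a real Scala Map value as input, something Spark has an encoder for, whereas MapType (and StructType) are only schema descriptors and never carry row data, which is presumably why reduce() receives an empty object in the attempt above.

For the two-column case, one way to avoid maps entirely is to make IN a small Java bean with one field per column, register the Aggregator via Encoders.bean, and pass both columns to the call. The following is an untested sketch, not DataFu code: TsVal, LsqBuffer and RateOfChangeAggregator are made-up names, and it assumes that functions.udaf binds the call's arguments to the bean encoder's fields in schema order (alphabetical by property name for beans), so verify against your Spark version.

// Sketch only: in a real project these would be separate public classes (or public static
// nested classes), since Encoders.bean requires accessible beans with no-arg constructors
// and getters/setters.

import java.io.Serializable;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

// Input bean: one field per input column.
public class TsVal implements Serializable {
    private long timestamp;
    private double value;

    public long getTimestamp() { return timestamp; }
    public void setTimestamp(final long timestamp) { this.timestamp = timestamp; }
    public double getValue() { return value; }
    public void setValue(final double value) { this.value = value; }
}

// Buffer bean: running sums for an incremental least-squares slope.
public class LsqBuffer implements Serializable {
    private long n;
    private double sumX;
    private double sumY;
    private double sumXY;
    private double sumXX;

    public long getN() { return n; }
    public void setN(final long n) { this.n = n; }
    public double getSumX() { return sumX; }
    public void setSumX(final double sumX) { this.sumX = sumX; }
    public double getSumY() { return sumY; }
    public void setSumY(final double sumY) { this.sumY = sumY; }
    public double getSumXY() { return sumXY; }
    public void setSumXY(final double sumXY) { this.sumXY = sumXY; }
    public double getSumXX() { return sumXX; }
    public void setSumXX(final double sumXX) { this.sumXX = sumXX; }
}

public class RateOfChangeAggregator extends Aggregator<TsVal, LsqBuffer, Double> {

    @Override
    public LsqBuffer zero() {
        return new LsqBuffer();
    }

    @Override
    public LsqBuffer reduce(final LsqBuffer b, final TsVal in) {
        // both column values arrive together in the bean
        b.setN(b.getN() + 1);
        b.setSumX(b.getSumX() + in.getTimestamp());
        b.setSumY(b.getSumY() + in.getValue());
        b.setSumXY(b.getSumXY() + in.getTimestamp() * in.getValue());
        b.setSumXX(b.getSumXX() + (double) in.getTimestamp() * in.getTimestamp());
        return b;
    }

    @Override
    public LsqBuffer merge(final LsqBuffer b1, final LsqBuffer b2) {
        b1.setN(b1.getN() + b2.getN());
        b1.setSumX(b1.getSumX() + b2.getSumX());
        b1.setSumY(b1.getSumY() + b2.getSumY());
        b1.setSumXY(b1.getSumXY() + b2.getSumXY());
        b1.setSumXX(b1.getSumXX() + b2.getSumXX());
        return b1;
    }

    @Override
    public Double finish(final LsqBuffer b) {
        // least-squares slope; undefined for fewer than two points (matching the nulls above)
        final double denom = b.getN() * b.getSumXX() - b.getSumX() * b.getSumX();
        if (b.getN() < 2 || denom == 0.0) {
            return null;
        }
        return (b.getN() * b.getSumXY() - b.getSumX() * b.getSumY()) / denom;
    }

    @Override
    public Encoder<LsqBuffer> bufferEncoder() {
        return Encoders.bean(LsqBuffer.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

Registering and applying it over the window would then look something like this (window spec simplified; the value column is cast so it matches the bean's double field):

sparkSession
        .get()
        .udf()
        .register("myUDAF", functions.udaf(new RateOfChangeAggregator(), Encoders.bean(TsVal.class)));

// Pass the two columns directly, without any map/struct wrapping.
final Dataset<Row> result = dataset.withColumn(
        "udaf",
        functions.expr("myUDAF(timestamp, CAST(value AS DOUBLE)) over (partition by instrument_id order by timestamp)"));

With fields named timestamp and value, the alphabetical bean-field order happens to coincide with the natural argument order, so the two-argument call lines up with the input schema the same way the original UDAF's inputSchema() did.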