I have written a Spark UDAF that takes two columns as input (timestamp and value) and calculates a rate of change via least squares over all data points in a given window. It works perfectly fine; the code is below (shortened to the relevant pieces).
UDAF
import com.google.common.collect.ImmutableList;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
public class MinimalUDAF extends UserDefinedAggregateFunction {
  @Override
  public StructType inputSchema() {
    return DataTypes.createStructType(ImmutableList.of(
        DataTypes.createStructField("timestamp", DataTypes.LongType, true),
        DataTypes.createStructField("value", DataTypes.DoubleType, true)));
  }

  @Override
  public StructType bufferSchema() {
    return DataTypes.createStructType(ImmutableList.of(
        DataTypes.createStructField("timestamps", DataTypes.createArrayType(DataTypes.LongType), false),
        DataTypes.createStructField("values", DataTypes.createArrayType(DataTypes.DoubleType), false)));
  }

  @Override
  public DataType dataType() {
    return DataTypes.DoubleType;
  }

  @Override
  public void update(final MutableAggregationBuffer buffer, final Row input) {
    final Long timestamp = input.getAs(0);
    final Double value = input.getAs(1);
    // append the point to the buffered timestamps/values arrays
    // ...
  }

  @Override
  public Object evaluate(final Row buffer) {
    // calculate and return rate of change (least-squares slope over the buffered points)
  }

  // remaining overrides (deterministic, initialize, merge) omitted
}
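For context, the elided update()/evaluate() logic simply buffers the (timestamp, value) points and computes an ordinary least-squares slope over them. Conceptually it boils down to something like the following sketch (the helper name and null handling are illustrative, not the exact production code):

import java.util.List;

// Ordinary least-squares slope over the buffered points:
// slope = (n * sum(x*y) - sum(x) * sum(y)) / (n * sum(x*x) - sum(x)^2)
static Double leastSquaresSlope(final List<Long> timestamps, final List<Double> values) {
  final int n = timestamps.size();
  double sumX = 0, sumY = 0, sumXY = 0, sumXX = 0;
  for (int i = 0; i < n; i++) {
    final double x = timestamps.get(i);
    final double y = values.get(i);
    sumX += x;
    sumY += y;
    sumXY += x * y;
    sumXX += x * x;
  }
  final double denominator = n * sumXX - sumX * sumX;
  // undefined for fewer than two distinct timestamps
  return denominator == 0 ? null : (n * sumXY - sumX * sumY) / denominator;
}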
Usage
@Test
public void testMinimalUDAF() {
  final WindowSpec windowSpec = Window.partitionBy("instrument_id")
      .orderBy("timestamp")
      // look back 24h
      .rangeBetween(-1 * 60 * 60 * 24, Window.currentRow());
  final Dataset<Row> dataset = sparkSession
      .get()
      .createDataFrame(
          ImmutableList.of(
              RowFactory.create(1, 1, 1671079600L),
              RowFactory.create(1, 101, 1671079650L),
              RowFactory.create(1, 401, 1671079700L),
              RowFactory.create(2, 50, 1671079630L),
              RowFactory.create(2, 60, 1671079640L)),
          StructType.fromDDL("instrument_id INT, value INT, timestamp LONG"))
      .withColumn(
          "udaf",
          new MinimalUDAF()
              .apply(functions.col("timestamp"), functions.col("value"))
              .over(windowSpec));
  dataset.show(false);
}
Result
+-------------+-----+----------+----+
|instrument_id|value|timestamp |udaf|
+-------------+-----+----------+----+
|1 |1 |1671079600|null|
|1 |101 |1671079650|2.0 |
|1 |401 |1671079700|4.0 |
|2 |50 |1671079630|null|
|2 |60 |1671079640|1.0 |
+-------------+-----+----------+----+
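(As a sanity check: for instrument_id 1, the last window contains the points (1671079600, 1), (1671079650, 101) and (1671079700, 401); centering the timestamps gives a least-squares slope of 20000 / 5000 = 4.0, which matches the udaf column. The first row of each partition is null since a single point does not define a slope.)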
As of Spark 3, UserDefinedAggregateFunction is deprecated and the recommendation is to use the Aggregator interface instead. I have seen the Java example in the Spark code base / docs, but it only uses a single column as input to the aggregator. I am struggling to adapt this to my case so that it takes two columns as arguments. I tried to use MapType / Struct as the IN class, but ran into issues.
Aggregator
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.types.MapType;
public class MinimalMapAggregator extends Aggregator<MapType, Average, Double> {
  @Override
  public Average reduce(final Average buffer, final MapType data) {
    // ...
  }

  // remaining Aggregator methods (zero, merge, finish, bufferEncoder, outputEncoder) omitted
}
Usage
@Test
public void testMinimalMapAggregator() {
  sparkSession
      .get()
      .udf()
      .register("myUDAF", functions.udaf(new MinimalMapAggregator(), Encoders.bean(MapType.class)));
  final Dataset<Row> dataset = sparkSession
      .get()
      .createDataFrame(
          ImmutableList.of(
              RowFactory.create(1, 1, 1671079600L),
              RowFactory.create(1, 101, 1671079650L),
              RowFactory.create(1, 401, 1671079700L),
              RowFactory.create(2, 50, 1671079630L),
              RowFactory.create(2, 60, 1671079640L)),
          StructType.fromDDL("instrument_id INT, value INT, timestamp LONG"))
      .withColumn(
          "ts_val",
          functions.map(
              functions.lit("timestamp"),
              functions.col("timestamp"),
              functions.lit("value"),
              functions.col("value")))
      .withColumn(
          "udaf", functions.expr("myUDAF(ts_val) over (partition by instrument_id order by timestamp)"));
  dataset.show(false);
}
When running this, the data variable in reduce() is just an empty MapType, so I can't get to the column values. I am seeing the same behavior with StructType.
How do I do this properly? Thanks a lot for your help!
You can find an example of a Spark Aggregator that uses a Map as input in Apache DataFu-Spark (disclosure: I am a member of DataFu and wrote this code). Take a look at the MapSetMerge Aggregator.
The declaration looks like this: