I'm trying to define a custom aggregation function that takes a StructType field as input, using the Aggregator API with DataFrames. The Spark version is 3.1.2.
Here's a reduced example: a basic one-field case class, passed in as a Row and reduced with a homemade first() aggregation:
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Encoder, Row, functions => f}
case class MyClass(
  myField: String
)

class MyFirst extends Aggregator[MyClass, MyClass, MyClass] {
  override def bufferEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def outputEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def zero: MyClass = null
  override def reduce(b: MyClass, a: MyClass): MyClass = {
    if (b == null) return a
    b
  }
  override def merge(b1: MyClass, b2: MyClass): MyClass = {
    reduce(b1, b2)
  }
  override def finish(reduction: MyClass): MyClass = {
    reduction
  }
}
class UDFsTestSuite extends SparkFunSuite with SharedSparkSession {
  test("test myFirst") {
    val input_df = spark.createDataFrame(
      spark.sparkContext.parallelize(
        Seq(
          Row(Row("a")),
          Row(Row("a")),
          Row(Row("b"))
        )
      ),
      StructType(
        List(
          StructField(
            "myClass",
            StructType(
              List(
                StructField("myField", StringType)
              )
            )
          )
        )
      )
    )

    val myFirst = f.udaf(new MyFirst())
    input_df.select(myFirst(f.col("myClass")).as("result")).show()
  }
}
This test throws the analysis error below. For some reason, the analyzer expects a string input, seemingly because that is the type of the first field in MyClass (if I define MyClass with an Int field instead, the error message changes to expecting an int; there's a sketch of that variant after the stack trace):
org.apache.spark.sql.AnalysisException: cannot resolve 'MyFirst(myClass)' due to data type mismatch: argument 1 requires string type, however, '`myClass`' is of struct<myField:string> type.;
'Aggregate [myfirst(myClass#1, utilities.MyFirst@d08f85a, class[myField[0]: string], class[myField[0]: string], true, true, 0, 0) AS result#6]
+- LogicalRDD [myClass#1], false
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:161)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:152)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:342)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:342)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:408)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:406)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:359)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:408)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:406)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:359)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$transformExpressionsUp$1(QueryPlan.scala:104)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$1(QueryPlan.scala:116)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:116)
at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:127)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$3(QueryPlan.scala:132)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:132)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$4(QueryPlan.scala:137)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:137)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsUp(QueryPlan.scala:104)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1(CheckAnalysis.scala:152)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1$adapted(CheckAnalysis.scala:93)
at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:184)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis(CheckAnalysis.scala:93)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis$(CheckAnalysis.scala:90)
at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:155)
at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:176)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:228)
at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:173)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:73)
at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:143)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:143)
at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:73)
at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:71)
at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:63)
at org.apache.spark.sql.Dataset$.$anonfun$ofRows$1(Dataset.scala:90)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:88)
at org.apache.spark.sql.Dataset.withPlan(Dataset.scala:3715)
at org.apache.spark.sql.Dataset.select(Dataset.scala:1462)
at utilities.UDFsTestSuite.$anonfun$new$1(UDFsTestSuite.scala:70)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
at org.scalatest.Transformer.apply(Transformer.scala:22)
at org.scalatest.Transformer.apply(Transformer.scala:20)
at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:190)
at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:176)
at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:188)
at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:200)
at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:200)
at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:182)
at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:61)
at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:61)
at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:233)
at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
at scala.collection.immutable.List.foreach(List.scala:431)
at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:233)
at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:232)
at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1563)
at org.scalatest.Suite.run(Suite.scala:1112)
at org.scalatest.Suite.run$(Suite.scala:1094)
at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1563)
at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:237)
at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:237)
at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:236)
at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:61)
at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:61)
at org.scalatest.tools.SuiteRunner.run(SuiteRunner.scala:45)
at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13(Runner.scala:1320)
at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13$adapted(Runner.scala:1314)
at scala.collection.immutable.List.foreach(List.scala:431)
at org.scalatest.tools.Runner$.doRunRunRunDaDoRunRun(Runner.scala:1314)
at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24(Runner.scala:993)
at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24$adapted(Runner.scala:971)
at org.scalatest.tools.Runner$.withClassLoaderAndDispatchReporter(Runner.scala:1480)
at org.scalatest.tools.Runner$.runOptionallyWithPassFailReporter(Runner.scala:971)
at org.scalatest.tools.Runner$.run(Runner.scala:798)
at org.scalatest.tools.Runner.run(Runner.scala)
at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.runScalaTest2or3(ScalaTestRunner.java:38)
at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.main(ScalaTestRunner.java:25)
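For completeness, here's roughly what I mean by the Int-field variant mentioned above. Only the case class changes; the class names below are made up for this sketch, and the error text is paraphrased from the same pattern as the string case:

case class MyClassInt(
  myField: Int
)

class MyFirstInt extends Aggregator[MyClassInt, MyClassInt, MyClassInt] {
  override def bufferEncoder: Encoder[MyClassInt] = ExpressionEncoder[MyClassInt]
  override def outputEncoder: Encoder[MyClassInt] = ExpressionEncoder[MyClassInt]
  override def zero: MyClassInt = null
  override def reduce(b: MyClassInt, a: MyClassInt): MyClassInt = if (b == null) a else b
  override def merge(b1: MyClassInt, b2: MyClassInt): MyClassInt = reduce(b1, b2)
  override def finish(reduction: MyClassInt): MyClassInt = reduction
}

// Registering this with f.udaf and applying it to the matching struct<myField:int>
// column fails the same way, but the message becomes (roughly):
//   argument 1 requires int type, however, '`myClass`' is of struct<myField:int> type.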
Oddly, if I just wrap the input objects in an array before passing them to the function, and adjust the aggregator's input type accordingly (Seq[MyClass] instead of MyClass), it works without issue:
class MyFirst2 extends Aggregator[Seq[MyClass], MyClass, MyClass] {
  override def bufferEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def outputEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def zero: MyClass = null
  override def reduce(b: MyClass, a: Seq[MyClass]): MyClass = {
    if (b == null) return a.head
    b
  }
  override def merge(b1: MyClass, b2: MyClass): MyClass = {
    reduce(b1, Seq(b2))
  }
  override def finish(reduction: MyClass): MyClass = {
    reduction
  }
}
class UDFsTestSuiteArray extends SparkFunSuite with SharedSparkSession {
  test("test myFirst") {
    val input_df = spark.createDataFrame(
      spark.sparkContext.parallelize(
        Seq(
          Row(Row("a")),
          Row(Row("a")),
          Row(Row("b"))
        )
      ),
      StructType(
        List(
          StructField(
            "myClass",
            StructType(
              List(
                StructField("myField", StringType)
              )
            )
          )
        )
      )
    )

    def myFirst(col: Column) = {
      val func = f.udaf(new MyFirst2())
      func(f.array(col))
    }

    input_df.select(myFirst(f.col("myClass")).as("result")).show()
  }
}
Output:
+------+
|result|
+------+
| {a}| // returns first element as expected
+------+
Any ideas why the first example doesn't work? Is this a possible Spark bug, or am I misunderstanding something about how the Aggregator API is meant to be used via f.udaf?
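For reference, here's roughly how I'd express the same aggregation through the typed Dataset API, which sidesteps f.udaf entirely. This is just a sketch of my understanding of the Aggregator contract (the suite/test names are made up, and it isn't part of the failing test above):

class UDFsTestSuiteTyped extends SparkFunSuite with SharedSparkSession {
  test("test myFirst, typed") {
    // Bind the session to a local val so its implicits can be imported
    // as a stable identifier.
    val session = spark
    import session.implicits._

    // Build the same input as a Dataset[MyClass] rather than a DataFrame
    // with a struct column.
    val input_ds = Seq(MyClass("a"), MyClass("a"), MyClass("b")).toDS()

    // Apply the Aggregator directly as a TypedColumn, bypassing f.udaf.
    input_ds.select(new MyFirst().toColumn.name("result")).show()
  }
}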