Issues defining an Aggregator with case class input

I'm trying to define a custom aggregation function that takes a StructType field as input, using the Aggregator API with DataFrames. The Spark version is 3.1.2.

Here's a reduced example: a basic one-field case class, passed in as a Row and reduced with a homemade first() aggregation:

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Encoder, Row, functions => f}

case class MyClass(
    myField: String
)

class MyFirst extends Aggregator[MyClass, MyClass, MyClass] {
  override def bufferEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]

  override def outputEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]

  override def zero: MyClass = null

  override def reduce(b: MyClass, a: MyClass): MyClass = {
    if (b == null) return a
    b
  }

  override def merge(b1: MyClass, b2: MyClass): MyClass = {
    reduce(b1, b2)
  }

  override def finish(reduction: MyClass): MyClass = {
    reduction
  }
}

class UDFsTestSuite extends SparkFunSuite with SharedSparkSession {

  test("test myFirst") {
    val input_df = spark.createDataFrame(
      spark.sparkContext.parallelize(
        Seq(
          Row(
            Row("a")
          ),
          Row(
            Row("a")
          ),
          Row(
            Row("b")
          )
        )
      ),
      StructType(
        List(
          StructField(
            "myClass",
            StructType(
              List(
                StructField("myField", StringType)
              )
            )
          )
        )
      )
    )

    val myFirst = f.udaf(new MyFirst())

    input_df.select(myFirst(f.col("myClass")).as("result")).show()
  }
}

This test throws the analysis error below. For some reason, the analyzer says it expects a string input, seemingly because that's the type of the first (and only) field in MyClass; if I define MyClass with an Int field instead, the error changes to expecting an int (see the sketch after the stack trace):

org.apache.spark.sql.AnalysisException: cannot resolve 'MyFirst(myClass)' due to data type mismatch: argument 1 requires string type, however, '`myClass`' is of struct<myField:string> type.;
'Aggregate [myfirst(myClass#1, utilities.MyFirst@d08f85a, class[myField[0]: string], class[myField[0]: string], true, true, 0, 0) AS result#6]
+- LogicalRDD [myClass#1], false

    at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:161)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:152)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:342)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:342)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:339)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:408)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:406)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:359)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:339)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:339)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:408)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:406)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:359)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:339)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$transformExpressionsUp$1(QueryPlan.scala:104)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$1(QueryPlan.scala:116)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:116)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:127)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$3(QueryPlan.scala:132)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:132)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$4(QueryPlan.scala:137)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:137)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsUp(QueryPlan.scala:104)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1(CheckAnalysis.scala:152)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1$adapted(CheckAnalysis.scala:93)
    at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:184)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis(CheckAnalysis.scala:93)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis$(CheckAnalysis.scala:90)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:155)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:176)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:228)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:173)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:73)
    at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:143)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:143)
    at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:73)
    at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:71)
    at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:63)
    at org.apache.spark.sql.Dataset$.$anonfun$ofRows$1(Dataset.scala:90)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:88)
    at org.apache.spark.sql.Dataset.withPlan(Dataset.scala:3715)
    at org.apache.spark.sql.Dataset.select(Dataset.scala:1462)
    at utilities.UDFsTestSuite.$anonfun$new$1(UDFsTestSuite.scala:70)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
    at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
    at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
    at org.scalatest.Transformer.apply(Transformer.scala:22)
    at org.scalatest.Transformer.apply(Transformer.scala:20)
    at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:190)
    at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:176)
    at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:188)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:200)
    at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:200)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:182)
    at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:61)
    at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
    at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
    at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:61)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:233)
    at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
    at scala.collection.immutable.List.foreach(List.scala:431)
    at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
    at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
    at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:233)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:232)
    at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1563)
    at org.scalatest.Suite.run(Suite.scala:1112)
    at org.scalatest.Suite.run$(Suite.scala:1094)
    at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1563)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:237)
    at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
    at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:237)
    at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:236)
    at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:61)
    at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
    at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
    at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
    at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:61)
    at org.scalatest.tools.SuiteRunner.run(SuiteRunner.scala:45)
    at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13(Runner.scala:1320)
    at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13$adapted(Runner.scala:1314)
    at scala.collection.immutable.List.foreach(List.scala:431)
    at org.scalatest.tools.Runner$.doRunRunRunDaDoRunRun(Runner.scala:1314)
    at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24(Runner.scala:993)
    at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24$adapted(Runner.scala:971)
    at org.scalatest.tools.Runner$.withClassLoaderAndDispatchReporter(Runner.scala:1480)
    at org.scalatest.tools.Runner$.runOptionallyWithPassFailReporter(Runner.scala:971)
    at org.scalatest.tools.Runner$.run(Runner.scala:798)
    at org.scalatest.tools.Runner.run(Runner.scala)
    at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.runScalaTest2or3(ScalaTestRunner.java:38)
    at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.main(ScalaTestRunner.java:25)
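
Just to illustrate the point above, here's a sketch of the Int variant I mentioned (names invented for illustration, same imports as the first example; I'm paraphrasing the error text from memory):

case class MyClassInt(
    myField: Int
)

class MyFirstInt extends Aggregator[MyClassInt, MyClassInt, MyClassInt] {
  override def bufferEncoder: Encoder[MyClassInt] = ExpressionEncoder[MyClassInt]

  override def outputEncoder: Encoder[MyClassInt] = ExpressionEncoder[MyClassInt]

  override def zero: MyClassInt = null

  override def reduce(b: MyClassInt, a: MyClassInt): MyClassInt = {
    if (b == null) a else b
  }

  override def merge(b1: MyClassInt, b2: MyClassInt): MyClassInt = reduce(b1, b2)

  override def finish(reduction: MyClassInt): MyClassInt = reduction
}

// Applying f.udaf(new MyFirstInt()) to the same kind of struct column fails in the same way,
// except the analyzer now says something like "argument 1 requires int type" -- i.e. the
// expected input type tracks the type of the case class's field.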

Oddly, if I just wrap the input objects in an array before passing them to the function, and adjust the Aggregator's input type accordingly, it works without issue:

class MyFirst2 extends Aggregator[Seq[MyClass], MyClass, MyClass] {
  override def bufferEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]

  override def outputEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]

  override def zero: MyClass = null

  override def reduce(b: MyClass, a: Seq[MyClass]): MyClass = {
    if (b == null) return a.head
    b
  }

  override def merge(b1: MyClass, b2: MyClass): MyClass = {
    reduce(b1, Seq(b2))
  }

  override def finish(reduction: MyClass): MyClass = {
    reduction
  }
}

class UDFsTestSuiteArray extends SparkFunSuite with SharedSparkSession {

  test("test myFirst") {
    val input_df = spark.createDataFrame(
      spark.sparkContext.parallelize(
        Seq(
          Row(
            Row("a")
          ),
          Row(
            Row("a")
          ),
          Row(
            Row("b")
          )
        )
      ),
      StructType(
        List(
          StructField(
            "myClass",
            StructType(
              List(
                StructField("myField", StringType)
              )
            )
          )
        )
      )
    )

    def myFirst(col: Column) = {
      val func = f.udaf(new MyFirst2())
      // Wrap the struct column in a single-element array so the aggregator
      // receives a Seq[MyClass] instead of a bare MyClass.
      func(f.array(col))
    }

    input_df.select(myFirst(f.col("myClass")).as("result")).show()
  }
}

Output:

+------+
|result|
+------+
|   {a}| // returns first element as expected
+------+
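
For what it's worth, my understanding is that the typed Dataset route is the documented way to use an Aggregator directly, and I'd expect it to side-step the column-type mismatch entirely, since the input encoder is applied to the whole Dataset element rather than to a single column. A minimal sketch of what I'd expect to work (assuming a stable `spark: SparkSession` val in scope; I haven't verified this against 3.1.2):

import spark.implicits._

val ds = Seq(MyClass("a"), MyClass("a"), MyClass("b")).toDS()

// toColumn turns the Aggregator into a TypedColumn[MyClass, MyClass],
// so the whole element is fed to reduce() rather than a single column.
val firstCol = new MyFirst().toColumn.name("result")

ds.select(firstCol).show()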

Any ideas why the first example isn't working? Is this a possible Spark bug, or am I just misunderstanding something about how the Aggregator API is meant to work?
