Spark UDAF: How to get value from input by column field name in UDAF (User-Defined Aggregation Function)?


I am trying to use a Spark UDAF to summarize two existing columns into a new column. Most tutorials on Spark UDAFs use indices to get the values from each column of the input Row, like this:

input.getAs[String](1)

which I use in my update method (override def update(buffer: MutableAggregationBuffer, input: Row): Unit), and it works in my case as well. However, I want to get the value using the field name of that column, like this:

input.getAs[String](ColumnNames.BehaviorType)

where ColumnNames.BehaviorType is a String defined in an object:

/**
 * Column names in the original dataset
 */
object ColumnNames {
  val JobSeekerID = "JobSeekerID"
  val JobID = "JobID"
  val Date = "Date"
  val BehaviorType = "BehaviorType"
}

This time it does not work. I got the following exception:

java.lang.IllegalArgumentException: Field "BehaviorType" does not exist.
  at org.apache.spark.sql.types.StructType$$anonfun$fieldIndex$1.apply(StructType.scala:292)
  ...
  at org.apache.spark.sql.Row$class.getAs(Row.scala:333)
  at org.apache.spark.sql.catalyst.expressions.GenericRow.getAs(rows.scala:165)
  at com.recsys.UserBehaviorRecordsUDAF.update(UserBehaviorRecordsUDAF.scala:44)

Some relevant code segments:

This is part of my UDAF:

class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {


  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
      StructField("BehaviorType", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
    println(input.schema.treeString)
    println
    println(input.mkString(","))
    println
    println(this.inputSchema.treeString)
    input.getAs[String](ColumnNames.BehaviorType) match { // works with index 1, but fails with the field name
      case BehaviourTypes.viewed_job =>
        buffer(0) = buffer.getAs[Seq[Int]](0) :+ input.getAs[Int](0) // JobID
      case BehaviourTypes.bookmarked_job =>
        buffer(1) = buffer.getAs[Seq[Int]](1) :+ input.getAs[Int](0) // JobID
      case BehaviourTypes.applied_job =>
        buffer(2) = buffer.getAs[Seq[Int]](2) :+ input.getAs[Int](0) // JobID
    }
  }

The following is the code that calls the UDAF:

val ubrUDAF = new UserBehaviorRecordsUDAF

val userProfileDF = userBehaviorDS
  .groupBy(ColumnNames.JobSeekerID)
  .agg(
    ubrUDAF(
      userBehaviorDS.col(ColumnNames.JobID),
      userBehaviorDS.col(ColumnNames.BehaviorType)
    ).as("profile str"))

It seems the field names in the schema of the input Row are not passed into the UDAF:

XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
root
 |-- input0: integer (nullable = true)
 |-- input1: string (nullable = true)


30917,viewed_job

root
 |-- JobID: integer (nullable = true)
 |-- BehaviorType: string (nullable = true)

What is the problem in my code?

1 Answer

I also wanted to use the field names from my inputSchema in my update method, to keep the code maintainable. The Row that Spark passes to update carries a generic schema (input0, input1, ...), so I re-wrapped it with the UDAF's own inputSchema:

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

class MyUDAF extends UserDefinedAggregateFunction {
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Re-attach the declared inputSchema so fields can be looked up by name
    val inputWSchema = new GenericRowWithSchema(input.toSeq.toArray, inputSchema)
    val behaviorType = inputWSchema.getAs[String](ColumnNames.BehaviorType)
    // ... rest of update unchanged
  }
}
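An alternative sketch (not from the original answer) that avoids building a new Row on every call: since inputSchema is declared by the UDAF itself, the field indices can be resolved once with StructType.fieldIndex and then used positionally. The class name and omitted members below are placeholders, assuming the schema from the question:

```scala
class UserBehaviorRecordsUDAF2 extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
      StructField("BehaviorType", StringType) :: Nil)

  // Resolved once from the schema declared above, so the names live in one place
  private lazy val jobIdIdx = inputSchema.fieldIndex("JobID")
  private lazy val behaviorTypeIdx = inputSchema.fieldIndex("BehaviorType")

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val jobId = input.getInt(jobIdIdx)
    input.getString(behaviorTypeIdx) match {
      // ... dispatch on the behaviour type as in the question
      case _ => ()
    }
  }
  // bufferSchema, dataType, deterministic, initialize, merge, evaluate omitted
}
```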

I ultimately switched to an Aggregator, which ran in half the time.
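For reference, a minimal sketch of what that Aggregator could look like. The case classes, object name, and three-bucket buffer below are assumptions reconstructed from the question, not the answerer's actual code; a typed Aggregator works on case-class fields, so the by-name access problem disappears entirely:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Typed input: fields are accessed by name directly, no Row schema needed
case class Behavior(JobID: Int, BehaviorType: String)
case class Buckets(viewed: Seq[Int], bookmarked: Seq[Int], applied: Seq[Int])

object UserBehaviorAgg extends Aggregator[Behavior, Buckets, Buckets] {
  def zero: Buckets = Buckets(Nil, Nil, Nil)

  def reduce(b: Buckets, a: Behavior): Buckets = a.BehaviorType match {
    case "viewed_job"     => b.copy(viewed = b.viewed :+ a.JobID)
    case "bookmarked_job" => b.copy(bookmarked = b.bookmarked :+ a.JobID)
    case "applied_job"    => b.copy(applied = b.applied :+ a.JobID)
    case _                => b
  }

  def merge(b1: Buckets, b2: Buckets): Buckets =
    Buckets(b1.viewed ++ b2.viewed,
            b1.bookmarked ++ b2.bookmarked,
            b1.applied ++ b2.applied)

  def finish(r: Buckets): Buckets = r
  def bufferEncoder: Encoder[Buckets] = Encoders.product[Buckets]
  def outputEncoder: Encoder[Buckets] = Encoders.product[Buckets]
}
```

It would be invoked on a typed Dataset, e.g. userBehaviorDS.as[Behavior].groupByKey(...).agg(UserBehaviorAgg.toColumn) (the grouping key would need to be carried in the case class or supplied separately).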