I am trying to use a Spark UDAF to summarize two existing columns into a new column. Most of the tutorials on Spark UDAFs out there use indices to get the value of each column of the input Row, like this:
input.getAs[String](1)
, which is used in my update method (override def update(buffer: MutableAggregationBuffer, input: Row): Unit
). It works in my case as well. However, I want to use the field name of that column to get that value, like this:
input.getAs[String](ColumnNames.BehaviorType)
, where ColumnNames.BehaviorType is a String constant defined in an object:
/**
* Column names in the original dataset
*/
object ColumnNames {
val JobSeekerID = "JobSeekerID"
val JobID = "JobID"
val Date = "Date"
val BehaviorType = "BehaviorType"
}
This time it does not work. I got the following exception:
java.lang.IllegalArgumentException: Field "BehaviorType" does not exist.
  at org.apache.spark.sql.types.StructType$$anonfun$fieldIndex$1.apply(StructType.scala:292)
  ...
  at org.apache.spark.sql.Row$class.getAs(Row.scala:333)
  at org.apache.spark.sql.catalyst.expressions.GenericRow.getAs(rows.scala:165)
  at com.recsys.UserBehaviorRecordsUDAF.update(UserBehaviorRecordsUDAF.scala:44)
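From the stack trace, getAs by field name resolves the name against the Row's own schema (via StructType.fieldIndex), so the failing call is roughly equivalent to the following paraphrase (my reading of the trace, not Spark's exact source):

// getAs-by-name goes through the Row's schema, so it throws if that schema has no "BehaviorType" field
val i = input.schema.fieldIndex(ColumnNames.BehaviorType) // <- this is where the IllegalArgumentException is raised
val v = input.getAs[String](i)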
Some relevant code segments:
This is part of my UDAF:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
    StructField("BehaviorType", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Debug output: compare the schema of the incoming Row with my declared inputSchema
    println("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
    println(input.schema.treeString)
    println
    println(input.mkString(","))
    println
    println(this.inputSchema.treeString)

    input.getAs[String](ColumnNames.BehaviorType) match { // works with index 1, fails with the field name -- why?
      case BehaviourTypes.viewed_job =>
        buffer(0) = buffer.getAs[Seq[Int]](0) :+ // the ArrayType buffer comes back as Seq[Int], not Array[Int]
          input.getAs[Int](0)                    // would prefer ColumnNames.JobID here as well
      case BehaviourTypes.bookmarked_job =>
        buffer(1) = buffer.getAs[Seq[Int]](1) :+
          input.getAs[Int](0)
      case BehaviourTypes.applied_job =>
        buffer(2) = buffer.getAs[Seq[Int]](2) :+
          input.getAs[Int](0)
    }
  }

  // bufferSchema, dataType, deterministic, initialize, merge and evaluate are omitted here
}
The following is the code that calls the UDAF:
val ubrUDAF = new UserBehaviorRecordsUDAF

val userProfileDF = userBehaviorDS
  .groupBy(ColumnNames.JobSeekerID)
  .agg(
    ubrUDAF(
      userBehaviorDS.col(ColumnNames.JobID),
      userBehaviorDS.col(ColumnNames.BehaviorType)
    ).as("profile str"))
It seems the field names declared in inputSchema are not carried over into the schema of the input Row that update receives:
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
root
|-- input0: integer (nullable = true)
|-- input1: string (nullable = true)
30917,viewed_job
root
|-- JobID: integer (nullable = true)
|-- BehaviorType: string (nullable = true)
What is the problem in my code?
I also want to be able to use the field names from my inputSchema in my update method, so that the code stays maintainable.
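The closest workaround I can think of is to resolve the indices once through my own inputSchema, since that StructType does carry the names. A rough sketch of what I mean (the jobIdIdx/behaviorTypeIdx names are just illustrative):

// inputSchema knows the field names, so look the indices up there instead of on the input Row
private lazy val jobIdIdx        = inputSchema.fieldIndex(ColumnNames.JobID)
private lazy val behaviorTypeIdx = inputSchema.fieldIndex(ColumnNames.BehaviorType)

// ...and inside update:
input.getAs[String](behaviorTypeIdx) match {
  case BehaviourTypes.viewed_job =>
    buffer(0) = buffer.getAs[Seq[Int]](0) :+ input.getAs[Int](jobIdIdx)
  // bookmarked_job and applied_job follow the same pattern
  case _ =>
}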
Update: I ultimately switched to a typed Aggregator, which ran in half the time.
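For reference, the replacement looked roughly like the sketch below; the case class names and the buffer layout are illustrative rather than my exact code, but they show the shape of the typed Aggregator, where the "field names" are simply case class fields:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Illustrative input and buffer types (not my exact production classes)
case class UserBehavior(JobSeekerID: Int, JobID: Int, Date: String, BehaviorType: String)
case class BehaviorLists(viewed: Seq[Int], bookmarked: Seq[Int], applied: Seq[Int])

object UserBehaviorAgg extends Aggregator[UserBehavior, BehaviorLists, BehaviorLists] {

  def zero: BehaviorLists = BehaviorLists(Nil, Nil, Nil)

  // With a typed input there is no Row at all, so no index lookups are needed
  def reduce(b: BehaviorLists, r: UserBehavior): BehaviorLists = r.BehaviorType match {
    case BehaviourTypes.viewed_job     => b.copy(viewed = b.viewed :+ r.JobID)
    case BehaviourTypes.bookmarked_job => b.copy(bookmarked = b.bookmarked :+ r.JobID)
    case BehaviourTypes.applied_job    => b.copy(applied = b.applied :+ r.JobID)
    case _                             => b
  }

  def merge(b1: BehaviorLists, b2: BehaviorLists): BehaviorLists =
    BehaviorLists(b1.viewed ++ b2.viewed, b1.bookmarked ++ b2.bookmarked, b1.applied ++ b2.applied)

  def finish(b: BehaviorLists): BehaviorLists = b

  def bufferEncoder: Encoder[BehaviorLists] = Encoders.product[BehaviorLists]
  def outputEncoder: Encoder[BehaviorLists] = Encoders.product[BehaviorLists]
}

// Usage, assuming userBehaviorDS is a Dataset[UserBehavior]:
// val userProfileDS = userBehaviorDS.groupByKey(_.JobSeekerID).agg(UserBehaviorAgg.toColumn)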