Convert a Spark 2.4 UDAF to a Spark 3.0 Aggregator

213 Views Asked by At

I have an already-written UDAF in Scala using Spark 2.4. Our Databricks cluster was on the 6.4 runtime, which is no longer supported, so we need to move to 7.3 LTS, which has long-term support and uses Spark 3. UDAF is deprecated in Spark 3 and will most likely be removed in the future, so I am trying to convert the UDAF into an Aggregator function.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{IntegerType,StringType, StructField, StructType, DataType}

/**
 * Aggregate function that, for each group, keeps the `(id, name)` pair with the
 * largest non-null `id`.
 *
 * Input, buffer and result all share the same two-column shape: a nullable
 * integer `id` and a nullable string `name`.
 *
 * Null handling: while the buffered id is still null, the next row is adopted
 * as-is (even if its own id is null, in which case its name is still copied).
 * Once a non-null id is buffered, rows with a null id are ignored.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated as of Spark 3.0 in
 * favour of `Aggregator` registered through `functions.udaf` -- confirm the
 * migration plan before relying on this class long term.
 */
object MaxCampaignIdAggregator extends UserDefinedAggregateFunction with java.io.Serializable {

  // Column positions shared by the input rows, the buffer and the result.
  private val IdIndex = 0
  private val NameIndex = 1

  // Single definition of the (id, name) layout; reused for input, buffer and output.
  private val idNameSchema: StructType = new StructType()
    .add("id", IntegerType, nullable = true)
    .add("name", StringType, nullable = true)

  /** Schema of the rows passed to [[update]]. */
  override def inputSchema: StructType = idNameSchema

  /** Schema of the intermediate aggregation buffer. */
  def bufferSchema: StructType = idNameSchema

  /** Type of the value produced by [[evaluate]]: a struct of `(id, name)`. */
  def dataType: DataType = idNameSchema

  /** The result depends only on the set of input rows. */
  def deterministic: Boolean = true

  /** Resets the buffer to the "nothing seen yet" state: both fields null. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(IdIndex) = null
    buffer(NameIndex) = null
  }

  /** Folds one input row into the buffer, keeping the pair with the larger id. */
  def update(buffer: MutableAggregationBuffer, inputRow: Row): Unit =
    keepLargerId(buffer, inputRow)

  /** Merges a partial aggregate (`buffer2`) into `buffer1`, keeping the pair with the larger id. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    keepLargerId(buffer1, buffer2)

  /** Emits the buffered `(id, name)` pair as a row. */
  def evaluate(buffer: Row): Any =
    Row(buffer.get(IdIndex), buffer.getString(NameIndex))

  /**
   * Overwrites `target` with `candidate` when `target` has no id yet, or when
   * `candidate` carries a non-null id strictly greater than the buffered one.
   */
  private def keepLargerId(target: MutableAggregationBuffer, candidate: Row): Unit = {
    if (target.isNullAt(IdIndex)) {
      // Nothing buffered yet: take the candidate verbatim, null id included.
      target(IdIndex) = candidate.get(IdIndex)
      target(NameIndex) = candidate.getString(NameIndex)
    } else if (!candidate.isNullAt(IdIndex)) {
      // Both ids are non-null here, so the primitive accessors are safe.
      val candidateId = candidate.getInt(IdIndex)
      if (candidateId > target.getInt(IdIndex)) {
        target(IdIndex) = candidateId
        target(NameIndex) = candidate.getString(NameIndex)
      }
    }
  }
}

When used, this gives the following output:

{"id": 1282, "name": "McCormick Christmas"}

{"id": 1305, "name": "McCormick Perfect Pinch"}

{"id": 1677, "name": "Viking Cruises Viking Cruises"}

0

There are 0 best solutions below