Spark UDAF with Window function


To meet my requirement ("process the provided data using the provided external library"), I wrote a UDAF in Spark/Scala that worked fine until I ran into the scenario below:

TestWindowFunc.scala

import org.apache.spark.sql.SparkSession

object TestWindowFunc {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TestWindowFunc")
      .master("local[3]")
      .config("spark.driver.memory", "5g")
      .getOrCreate()

    spark.udf.register("custAvg", new CustAvg)

    val df = spark.read.option("delimiter", "|").option("header", "true")
      .csv("./src/main/resources/students_mark.csv")

    df.createOrReplaceTempView("testWindowFunc")

    val df1 = spark.sql("select X.*" +
      ", custAvg(ACT_MARK, OUT_OF) over (partition by STUDENT_ID order by ACT_MARK) a" +
      ", custAvg(ACT_MARK, OUT_OF) over (partition by STUDENT_ID order by ACT_MARK) b" +
      " from testWindowFunc X")

    df1.show()
  }
}
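
A quick way to check how Spark plans the two identical window expressions (a diagnostic sketch, not part of the run above) is to print the physical plan before calling show():

df1.explain(true)

If both custAvg calls appear inside a single Window operator, they are executed together in one pass over each partition.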

CustAvg.scala

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}

class CustAvg extends UserDefinedAggregateFunction {
  var initializeCounter = 0
  var updateCounter = 0

  override def inputSchema: StructType = StructType(Array(
    StructField("act_mark", IntegerType),
    StructField("out_of", IntegerType)
  ))

  override def bufferSchema: StructType = StructType(Array(
    StructField("act_mark_tot", LongType),
    StructField("out_of_tot", LongType)
  ))

  override def dataType: DataType = LongType

  override def deterministic: Boolean = false

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    initializeCounter += 1
    println("initialize:::" + initializeCounter)
    updateCounter = 0

    /**
     * initializing the external library for each window
     */
    //    uncomment the below lines to execute the function
    //    buffer(0) = 0L
    //    buffer(1) = 0L
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    updateCounter += 1
    println("update:::" + updateCounter)

    /**
     * sending data to the external library for each row of the respective window
     */
    //    uncomment the below lines to execute the function
    //    buffer(0) = buffer.getLong(0) + input.getInt(0)
    //    buffer(1) = buffer.getLong(1) + input.getInt(1)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    throw new Exception("Merge Not Allowed")
  }

  override def evaluate(buffer: Row): Any = {
    println("evaluate:::" + updateCounter)

    /**
     * calling the external library to process the data
     */
    //    uncomment the below line to execute the function
    //    buffer.getLong(0)
  }
}
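
Note that initializeCounter and updateCounter are fields on the CustAvg instance, not part of the aggregation buffer, so every window expression that resolves to the single registered instance increments the same counters. A minimal sketch of tracking the count per window instead (assuming an extra LongType buffer field is acceptable):

  // Hypothetical variant: keep the counter in the buffer so that it is
  // scoped to one window rather than to the shared UDAF instance.
  override def bufferSchema: StructType = StructType(Array(
    StructField("act_mark_tot", LongType),
    StructField("out_of_tot", LongType),
    StructField("update_count", LongType)
  ))

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
    buffer(2) = 0L // per-window counter starts at zero
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getInt(0)
    buffer(1) = buffer.getLong(1) + input.getInt(1)
    buffer(2) = buffer.getLong(2) + 1L // rows seen by this window so far
  }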

students_mark.csv

STUDENT_ID|ACT_MARK|OUT_OF
1|70|100
1|68|100
1|90|100

Expected output

initialize:::1
update:::1
evaluate:::1
update:::2
evaluate:::2
update:::3
evaluate:::3
initialize:::2
update:::1
evaluate:::1
update:::2
evaluate:::2
update:::3
evaluate:::3

Actual output

initialize:::1
initialize:::2
update:::1
update:::2
evaluate:::2
evaluate:::2
update:::3
update:::4
evaluate:::4
evaluate:::4
update:::5
update:::6
evaluate:::6
evaluate:::6
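
For comparison, a sketch that drops one of the two (identical) window expressions, to isolate whether the doubling comes from evaluating two expressions over the same window spec:

    val df2 = spark.sql("select X.*" +
      ", custAvg(ACT_MARK, OUT_OF) over (partition by STUDENT_ID order by ACT_MARK) a" +
      " from testWindowFunc X")

    df2.show()

With a single expression the counters should print in the expected initialize/update/evaluate order, which would be consistent with Spark grouping window functions that share a frame into one Window operator and driving them together, row by row.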

Is this how Spark behaves in this scenario, or am I doing something wrong here?

Could someone please explain this behavior?

Version details:

  • scala: 2.11
  • spark: 2.4.0

Thanks in advance.
