Pass a function with any case class return type as parameter


This might be a silly question, but I've been struggling for quite some time. It is similar to this question, but I wasn't able to apply it in my code (due to the patterns involved, or to the argument being a function).

I want to pass a flatMap (or map) transform function as a function argument and then forward it to a strategy class that actually calls the df.rdd.flatMap method. I'll try to explain!

case class Order(id: String, totalValue: Double, freight: Double) 
case class Product(id: String, price: Double) 

... or any other case class, whatever one needs to transform a row into ...

The Entity class:

class Entity(path: String) {
  ...
  def flatMap[T](mapFunction: (Row) => ArrayBuffer[T]): Entity = {
      this.getStrategy.flatMap[T](mapFunction)
      this
  }
  def save(path: String): Unit = {
      ... write logic ...
  } 
}

An Entity might have different strategies for its methods. EntityStrategy is as follows:

abstract class EntityStrategy(private val entity: Entity,
                              private val spark: SparkSession) {
  ...
  def flatMap[T](mapFunction: (Row) => ArrayBuffer[T]): Unit
  def map[T](mapFunction: (Row) => T): Unit
}

And one sample EntityStrategy implementation:

class SparkEntityStrategy(private val entity: Entity, private val spark: SparkSession)
  extends EntityStrategy(entity, spark) {
  ...
  override def map[T](mapFunction: Row => T): Unit = {
    val rdd = this.getData.rdd.map(f = mapFunction)
    this.dataFrame = this.spark.createDataFrame(rdd)
  }

  override def flatMap[T](mapFunction: (Row) => ArrayBuffer[T]): Unit = {
    val rdd = this.getData.rdd.flatMap(f = mapFunction)
    this.dataFrame = this.spark.createDataFrame(rdd)
  }
}

Finally, I would like to create a flatMap/map function and call it like this:

def transformFlatMap(row: Row): ArrayBuffer[Order] = {
    val orders = new ArrayBuffer[Order]
    val _deliveries = row.getAs[Seq[Row]]("deliveries")
    _deliveries.foreach(_delivery => {
       val order = Order(
           id = row.getAs[String]("id"),
           totalValue = _delivery.getAs[Double]("totalAmount"),
           freight = _delivery.getAs[Double]("freight")) // Order requires all three fields; a "freight" column is assumed here
       orders += order
    })
    orders
}

val entity = new Entity("path")
entity.flatMap[Order](transformFlatMap).save("path")

Of course, this doesn't work. I get an error on SparkEntityStrategy:

Error:(95, 35) No ClassTag available for T val rdd = this.getData.rdd.map(f = mapFunction)

I have tried adding an (implicit encoder: Encoder[T]) to both the entity and strategy methods, but it was a no-go. I probably did something wrong, as I'm new to Scala.

If I remove the type parameter T and pass an actual case class, everything works fine.
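
For reference, a minimal sketch of the monomorphic variant that does compile (hardcoding Order as the return type, so the compiler can supply the ClassTag and TypeTag on its own):

override def flatMap(mapFunction: Row => ArrayBuffer[Order]): Unit = {
  // Order is a concrete case class, so the implicits resolve at this site
  val rdd = this.getData.rdd.flatMap(mapFunction)
  this.dataFrame = this.spark.createDataFrame(rdd)
}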

There is 1 answer below.

It turns out that, in order to satisfy both the compiler and Spark's methods, I needed the following bounds on the type parameter:

[T <: scala.Product : ClassTag : TypeTag]

So both methods became:

def map[T <: Product : ClassTag : TypeTag](mapFunction: (Row) => T): Entity
def flatMap[T <: scala.Product : ClassTag : TypeTag](mapFunction: (Row) => TraversableOnce[T]): Entity
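
Putting it together, here is a sketch of the fixed SparkEntityStrategy (assuming the same getData and dataFrame members as in the question; the abstract declarations in EntityStrategy need the same bounds):

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.{Row, SparkSession}

class SparkEntityStrategy(private val entity: Entity, private val spark: SparkSession)
  extends EntityStrategy(entity, spark) {

  // ClassTag satisfies RDD.map/flatMap; TypeTag and the Product bound
  // satisfy the case-class overload of createDataFrame
  override def map[T <: Product : ClassTag : TypeTag](mapFunction: Row => T): Unit = {
    val rdd = this.getData.rdd.map(mapFunction)
    this.dataFrame = this.spark.createDataFrame(rdd)
  }

  override def flatMap[T <: Product : ClassTag : TypeTag](mapFunction: Row => TraversableOnce[T]): Unit = {
    val rdd = this.getData.rdd.flatMap(mapFunction)
    this.dataFrame = this.spark.createDataFrame(rdd)
  }
}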

About scala.Product:

Base trait for all products, which in the standard library include at least scala.Product1 through scala.Product22 and therefore also their subclasses scala.Tuple1 through scala.Tuple22. In addition, all case classes implement Product with synthetically generated methods.
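
A quick illustration with the Order case class from the question:

val order = Order("o-1", 10.0, 2.5)
order.isInstanceOf[Product] // true -- case classes implement Product
order.productArity          // 3
order.productElement(0)     // "o-1"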

Since I am using a case class as my function's return type, I needed the scala.Product bound so that Spark's createDataFrame could resolve the correct overload.
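
The overload in question on SparkSession is declared roughly as:

def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame

which is exactly where both the Product bound and the TypeTag requirement come from.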

Why both ClassTag and TypeTag?

By removing the TypeTag, the compiler throws the following error:

Error:(96, 48) No TypeTag available for T this.dataFrame = this.spark.createDataFrame(rdd)

And removing the ClassTag:

Error:(95, 35) No ClassTag available for T val rdd = this.getData.rdd.map(f = mapFunction)

Adding both tags satisfied the compiler, and everything worked as expected.
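
As a small standalone sketch of what each tag recovers at runtime (List[Int] is just an illustrative type argument):

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, typeOf}

// ClassTag carries only the runtime class -- enough for RDD.map/flatMap
def withClassTag[T: ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName

// TypeTag carries the full compile-time type -- enough for Spark's schema reflection
def withTypeTag[T: TypeTag]: String = typeOf[T].toString

withClassTag[List[Int]] // "scala.collection.immutable.List" -- the Int is erased
withTypeTag[List[Int]]  // "List[Int]" -- the full type is preserved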

I also found a good article explaining type erasure in Scala.