Key corresponding to max value in a spark map column


If I have a Spark map column from string to double, is there an easy way to generate a new column containing the key corresponding to the maximum value?

I was able to achieve it using collection functions as illustrated below:

import org.apache.spark.sql.functions._
import spark.implicits._  // needed for .toDF on a local Seq

val mockedDf = Seq(1, 2, 3)
  .toDF("id")
  .withColumn("optimized_probabilities_map", typedLit(Map("foo" -> 0.34333337, "bar" -> 0.23)))
val df = mockedDf
  .withColumn("optimizer_probabilities", map_values($"optimized_probabilities_map"))
  .withColumn("max_probability", array_max($"optimizer_probabilities"))
  .withColumn("max_position", array_position($"optimizer_probabilities", $"max_probability"))
  .withColumn("optimizer_ruler_names", map_keys($"optimized_probabilities_map"))
  // array_position is 1-based, so use the 1-based element_at to pick the matching key
  .withColumn("optimizer_ruler_name", element_at($"optimizer_ruler_names", $"max_position".cast("int")))

However, this solution is unnecessarily long and not very efficient. There is also a possible precision issue, since I am comparing doubles when using array_position. I wonder if there is a better way to do this without UDFs, maybe using an expression string.

There are 2 answers below.

Accepted answer:

Since you can use Spark 2.4+, one way is to use the Spark SQL built-in higher-order function aggregate: we iterate through map_keys, compare each corresponding map value with the buffered value acc.val, and update acc.name accordingly:

mockedDf.withColumn("optimizer_ruler_name", expr("""
    aggregate(
      map_keys(optimized_probabilities_map), 
      (string(NULL) as name, double(NULL) as val),
      (acc, y) ->
        IF(acc.val is NULL OR acc.val < optimized_probabilities_map[y]
        , (y as name, optimized_probabilities_map[y] as val)
        , acc
        ),
      acc -> acc.name
    )
""")).show(false)
+---+--------------------------------+--------------------+
|id |optimized_probabilities_map     |optimizer_ruler_name|
+---+--------------------------------+--------------------+
|1  |[foo -> 0.34333337, bar -> 0.23]|foo                 |
|2  |[foo -> 0.34333337, bar -> 0.23]|foo                 |
|3  |[foo -> 0.34333337, bar -> 0.23]|foo                 |
+---+--------------------------------+--------------------+
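
A closely related sketch, assuming the same Spark 2.4+ setup: aggregate over map_entries instead of map_keys, so the lambda only needs the current (key, value) entry rather than probing the map again for each key:

mockedDf.withColumn("optimizer_ruler_name", expr("""
    aggregate(
      map_entries(optimized_probabilities_map),
      cast(NULL as struct<name: string, val: double>),
      (acc, e) ->
        IF(acc IS NULL OR acc.val < e.value
        , named_struct('name', e.key, 'val', e.value)
        , acc
        ),
      acc -> acc.name
    )
""")).show(false)
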
Second answer:

Another solution would be to explode the map column and then use a Window function to get the maximum value, like this:

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"id")

// explode on a map column yields one row per entry, with "key" and "value" columns
val df = mockedDf.select($"id", $"optimized_probabilities_map", explode($"optimized_probabilities_map"))
                 .withColumn("max_value", max($"value").over(w))
                 .where($"max_value" === $"value")
                 .drop("value", "max_value")
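
A quick usage sketch: renaming the exploded key column keeps the result aligned with the accepted answer's optimizer_ruler_name, and with the mocked data every id should again resolve to foo. Note that if two keys tie for the maximum value, this window-based approach keeps one row per tied key, while the aggregate approach returns a single key.

// Rename the "key" column produced by explode to match the first answer's output
df.withColumnRenamed("key", "optimizer_ruler_name").show(false)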