Apache Spark: add a "CASE WHEN ... ELSE ..." calculated column to an existing DataFrame


I'm trying to add a "CASE WHEN ... ELSE ..." calculated column to an existing DataFrame, using the Scala API. Starting DataFrame:

color
Red
Green
Blue

Desired DataFrame (SQL syntax: CASE WHEN color = 'Green' THEN 1 ELSE 0 END AS bool):

color bool
Red   0
Green 1
Blue  0

How should I implement this logic?

There are 4 answers below.

Accepted answer:

In the upcoming Spark 1.4.0 release (which should be out in the next couple of days), you can use the when/otherwise syntax:

// In the spark-shell, sqlContext provides the implicits for $"..." and toDF
import sqlContext.implicits._
import org.apache.spark.sql.functions.when

// Create the DataFrame
val df = Seq("Red", "Green", "Blue").map(Tuple1.apply).toDF("color")

// Use when/otherwise syntax
val df1 = df.withColumn("Green_Ind", when($"color" === "Green", 1).otherwise(0))
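
For the sample data, df1.show() should print something like this (illustrative output):

df1.show()
// +-----+---------+
// |color|Green_Ind|
// +-----+---------+
// |  Red|        0|
// |Green|        1|
// | Blue|        0|
// +-----+---------+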

If you are using Spark 1.3.0 you can choose to use a UDF:

import org.apache.spark.sql.functions.udf

// Define the UDF
val isGreen = udf((color: String) => {
  if (color == "Green") 1
  else 0
})

// Apply it to the color column
val df2 = df.withColumn("Green_Ind", isGreen($"color"))
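
Both variants produce the same Green_Ind column; where when/otherwise is available it is generally the better choice, since a Scala UDF is a black box to the Catalyst optimizer and cannot be optimized.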

Answer:

I was looking for this for a long time, so here is an example with groupBy for Spark 2.1 in Java, for other Java users.

import static org.apache.spark.sql.functions.*;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
 //...
    // Reusable boolean conditions on the "uniq" and "testMode" columns
    Column uniqTrue = col("uniq").equalTo(true);
    Column uniqFalse = col("uniq").equalTo(false);

    Column testModeFalse = col("testMode").equalTo(false);
    Column testModeTrue = col("testMode").equalTo(true);

    // For each group, count the rows falling into each of the four combinations
    Dataset<Row> x = basicEventDataset
            .groupBy(col(group_field))
            .agg(
                    sum(when((testModeTrue).and(uniqTrue), 1).otherwise(0)).as("tt"),
                    sum(when((testModeFalse).and(uniqTrue), 1).otherwise(0)).as("ft"),
                    sum(when((testModeTrue).and(uniqFalse), 1).otherwise(0)).as("tf"),
                    sum(when((testModeFalse).and(uniqFalse), 1).otherwise(0)).as("ff")
            );
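
The same sum(when(...).otherwise(...)) conditional-aggregation pattern, applied to the original color example in Scala, would look roughly like this (a sketch assuming the df defined in the accepted answer):

import org.apache.spark.sql.functions.{col, sum, when}

// Count "Green" rows vs. everything else in a single aggregation
val colorCounts = df.agg(
  sum(when(col("color") === "Green", 1).otherwise(0)).as("green_count"),
  sum(when(col("color") === "Green", 0).otherwise(1)).as("other_count")
)
colorCounts.show()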

Answer:

I found this:

https://issues.apache.org/jira/browse/SPARK-3813

This worked for me on Spark 2.1.0:

// spark-shell example; Record is a simple case class used to build the table
case class Record(key: Int, value: String)

import sqlContext.implicits._
val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
rdd.toDF().registerTempTable("records")
println("Result of the CASE query:")
sqlContext.sql("SELECT case key when '93' then 'ravi' else key end FROM records").collect()

Answer:

In Spark 1.5.0 and later you can also use the expr function, which takes SQL syntax:

import org.apache.spark.sql.functions.expr
val df3 = df.withColumn("Green_Ind", expr("case when color = 'Green' then 1 else 0 end"))

or plain Spark SQL:

df.registerTempTable("data")
val df4 = sqlContext.sql(""" select *, case when color = 'Green' then 1 else 0 end as Green_Ind from data """)