Spark In Depth: Group by with Case statement
I recently encountered a tricky problem while working on a project. I needed to group data for a use case that involved several conditions and aggregation requirements: I had to calculate averages for some records and sums for others. The input and output data are shown in the screenshot below. After exploring different options, I came up with the solution below, which needs only a single data shuffle.
One naïve solution would be to filter the data on the grouping conditions into two data frames, perform two separate group-by operations, and then union the outputs. However, this approach would require three shuffles. To minimize shuffling, I decided to use a case expression inside the group by.
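To make the trade-off concrete, here is a minimal plain-Scala sketch of that naïve two-pass logic, with collection operations standing in for the DataFrame filter, groupBy, and union steps (the object and value names here are illustrative, not from the original code):

```scala
object NaiveTwoPass {
  // Same sample rows as the Spark example: (id, op, value)
  val rows = Seq(
    (1, "sum", 100), (1, "sum", 100), (1, "avg", 50), (1, "avg", 60),
    (2, "sum", 200), (2, "sum", 200), (2, "avg", 20), (2, "avg", 20))

  // Pass 1: keep only the "sum" rows and total them per id parity
  val sums: Map[String, Double] = rows.filter(_._2 == "sum")
    .groupBy(r => if (r._1 % 2 == 0) "even" else "odd")
    .map { case (k, rs) => k -> rs.map(_._3).sum.toDouble }

  // Pass 2: keep only the "avg" rows and average them per id parity
  val avgs: Map[String, Double] = rows.filter(_._2 == "avg")
    .groupBy(r => if (r._1 % 2 == 0) "even" else "odd")
    .map { case (k, rs) => k -> rs.map(_._3).sum.toDouble / rs.size }

  // "Union" of the two partial results. In Spark, each groupBy above
  // is its own shuffle, and combining the outputs adds a third.
  val combined: Map[(String, String), Double] =
    sums.map { case (k, v) => (k, "sum") -> v } ++
      avgs.map { case (k, v) => (k, "avg") -> v }
}
```

The collections version is cheap either way; the cost only matters on a cluster, where every groupBy on a DataFrame repartitions data across the network.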
Spark Code —
package org.example

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object Ex_GroupByCase {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local")
      .appName("GroupByCase").getOrCreate()
    import spark.implicits._

    // Sample input: (id, op, value) — "op" decides whether the value
    // should be summed or averaged within its group.
    val inputData = Seq(
      (1, "sum", 100), (1, "sum", 100), (1, "avg", 50), (1, "avg", 60),
      (2, "sum", 200), (2, "sum", 200), (2, "avg", 20), (2, "avg", 20))
      .toDF("id", "op", "value")

    // Group by a case expression (even/odd id) together with the op
    // column. Each group then holds rows of a single op, so exactly one
    // of the two conditional aggregates is non-null and coalesce picks it.
    val outputDataGroup = inputData.groupBy(
        when(col("id") % 2 === 0, "even").otherwise("odd") as "oddEven",
        col("op"))
      .agg(coalesce(
        sum(when(col("op") === "sum", col("value"))),
        avg(when(col("op") === "avg", col("value")))) as "outcome")

    outputDataGroup.show(false)
    spark.stop()
  }
}
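To sanity-check what the query computes, the same single-pass grouping logic can be reproduced with plain Scala collections. This is only a sketch over the sample data above, with names of my choosing; it mirrors the groupBy on (oddEven, op) and the coalesce(sum, avg) aggregate:

```scala
object ExpectedOutcome {
  // Same sample rows as the Spark job: (id, op, value)
  val rows = Seq(
    (1, "sum", 100), (1, "sum", 100), (1, "avg", 50), (1, "avg", 60),
    (2, "sum", 200), (2, "sum", 200), (2, "avg", 20), (2, "avg", 20))

  // One pass: group by (even/odd parity, op), then sum or average each
  // group depending on its op — the role coalesce plays in the Spark query.
  val outcome: Map[(String, String), Double] =
    rows.groupBy(r => (if (r._1 % 2 == 0) "even" else "odd", r._2))
      .map { case ((parity, op), rs) =>
        val vals = rs.map(_._3.toDouble)
        (parity, op) -> (if (op == "sum") vals.sum else vals.sum / vals.size)
      }
}
```

For the sample input this yields 200 and 400 for the odd/even "sum" groups and 55 and 20 for the odd/even "avg" groups, which is what show(false) should print (modulo row order).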
Spark code with DAG:
We solved the problem with a single group by, so the whole job needs only one shuffle and sort. As shown in the image below, that group by triggers the data exchange, splitting the job into two stages.