Spark In Depth: Group by with Case statement

Sujit J Fulse
2 min read · Feb 24, 2024


I recently encountered a complex problem while working on my project. I needed to group data for a use case that involved several conditions and aggregation requirements: I had to calculate averages for some records and sums for others. The input and output data are shown in the screenshot below. After exploring different options, I came up with the solution below, which needs only one data shuffle.

One naïve solution would be to filter the data on the grouping conditions, create two data frames, run two group operations, and then combine the outputs (a sketch of this approach is shown below). However, that would require three shuffles. To minimize shuffling, I decided to use a case expression inside the groupBy.
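For comparison, here is a rough sketch of that naïve approach. It is only illustrative and assumes the same spark session, imports, and inputData DataFrame used in the full example later in this post; the helper column name oddEven is my own.

// Naive approach: filter into two DataFrames, aggregate each separately, then union.
// Each groupBy below triggers its own shuffle.
val oddEven = when(col("id") % 2 === 0, "even").otherwise("odd") as "oddEven"

val sums = inputData.filter(col("op") === "sum")
  .groupBy(oddEven, col("op"))
  .agg(sum("value").cast("double") as "outcome") // cast so the schemas match for union

val avgs = inputData.filter(col("op") === "avg")
  .groupBy(oddEven, col("op"))
  .agg(avg("value") as "outcome")

val naiveOutput = sums.union(avgs) // union itself does not reshuffle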

Input and output

Spark code:

package org.example

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object Ex_GroupByCase {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local")
      .appName("GroupByCase")
      .getOrCreate()

    import spark.implicits._

    // Sample input: each row carries the aggregation ("sum" or "avg") it belongs to.
    val inputData = Seq(
      (1, "sum", 100), (1, "sum", 100), (1, "avg", 50), (1, "avg", 60),
      (2, "sum", 200), (2, "sum", 200), (2, "avg", 20), (2, "avg", 20)
    ).toDF("id", "op", "value")

    // Group by a case expression (odd/even id) together with the op column,
    // then pick the right aggregate per group: sum() for "sum" rows, avg() for "avg" rows.
    // coalesce works because sum(when(...)) is null for groups that have no "sum" rows.
    val outputDataGroup = inputData
      .groupBy(
        when(col("id") % 2 === 0, "even").otherwise("odd") as "oddEven",
        col("op") as "op"
      )
      .agg(
        coalesce(
          sum(when(col("op") === "sum", col("value"))),
          avg(when(col("op") === "avg", col("value")))
        ) as "outcome"
      )

    outputDataGroup.show(false)
    spark.stop()
  }
}
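
Running this against the sample input, the result should look roughly like the output below (row order after the shuffle may vary): odd ids (id = 1) get sum = 200 and avg = 55, while even ids (id = 2) get sum = 400 and avg = 20.

+-------+---+-------+
|oddEven|op |outcome|
+-------+---+-------+
|odd    |sum|200.0  |
|odd    |avg|55.0   |
|even   |sum|400.0  |
|even   |avg|20.0   |
+-------+---+-------+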

Spark code with DAG:

We solved the problem with a single group by, so the job needs only one shuffle and sort. As shown in the image below, the group by triggered the data shuffle, splitting the job into two stages.

Spark code with stages
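
Besides the Spark UI, a quick way to confirm the shuffle count is to print the physical plan; for this query it should show a single Exchange (shuffle) operator between the partial and final aggregations.

outputDataGroup.explain()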


Sujit J Fulse

I am a Lead Data Engineer with experience building end-to-end data pipelines. Please connect with me at https://www.linkedin.com/in/sujit-j-fulse