I'm executing the following code:
from pyspark.sql import types as T, functions as F, SparkSession
spark = SparkSession.builder.getOrCreate()
schema = T.StructType([
    T.StructField("col_1", T.IntegerType(), False),
    T.StructField("col_2", T.IntegerType(), False),
    T.StructField("measure_1", T.FloatType(), False),
    T.StructField("measure_2", T.FloatType(), False),
])
data = [
    {"col_1": 1, "col_2": 2, "measure_1": 0.5, "measure_2": 1.5},
    {"col_1": 2, "col_2": 3, "measure_1": 2.5, "measure_2": 3.5},
]
df = spark.createDataFrame(data, schema)
df.show()
"""
+-----+-----+---------+---------+
|col_1|col_2|measure_1|measure_2|
+-----+-----+---------+---------+
|    1|    2|      0.5|      1.5|
|    2|    3|      2.5|      3.5|
+-----+-----+---------+---------+
"""
group_cols = ["col_1", "col_2"]
measure_cols = ["measure_1", "measure_2"]
for col in measure_cols:
    stats = df.groupBy(group_cols).agg(
        F.max(col).alias("max_" + col),
        F.avg(col).alias("avg_" + col),
    )
    df = df.join(stats, group_cols)
df.show()
"""
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
"""
The problem arises when my initial df isn't this simple but is instead the result of a series of joins or other operations. When I look at my job, df appears to be re-derived several times as the groupBy operations execute. The query plan for the simple case is:
df.explain()
"""
== Physical Plan ==
*(11) Project [col_1#26, col_2#27, measure_1#28, measure_2#29, max_measure_1#56, avg_measure_1#58, max_measure_2#80, avg_measure_2#82]
+- *(11) SortMergeJoin [col_1#26, col_2#27], [col_1#87, col_2#88], Inner
   :- *(5) Project [col_1#26, col_2#27, measure_1#28, measure_2#29, max_measure_1#56, avg_measure_1#58]
   :  +- *(5) SortMergeJoin [col_1#26, col_2#27], [col_1#63, col_2#64], Inner
   :     :- *(2) Sort [col_1#26 ASC NULLS FIRST, col_2#27 ASC NULLS FIRST], false, 0
   :     :  +- Exchange hashpartitioning(col_1#26, col_2#27, 200), ENSURE_REQUIREMENTS, [id=#276]
   :     :     +- *(1) Scan ExistingRDD[col_1#26,col_2#27,measure_1#28,measure_2#29]
   :     +- *(4) Sort [col_1#63 ASC NULLS FIRST, col_2#64 ASC NULLS FIRST], false, 0
   :        +- *(4) HashAggregate(keys=[col_1#63, col_2#64], functions=[max(measure_1#65), avg(cast(measure_1#65 as double))])
   :           +- Exchange hashpartitioning(col_1#63, col_2#64, 200), ENSURE_REQUIREMENTS, [id=#282]
   :              +- *(3) HashAggregate(keys=[col_1#63, col_2#64], functions=[partial_max(measure_1#65), partial_avg(cast(measure_1#65 as double))])
   :                 +- *(3) Project [col_1#63, col_2#64, measure_1#65]
   :                    +- *(3) Scan ExistingRDD[col_1#63,col_2#64,measure_1#65,measure_2#66]
   +- *(10) Sort [col_1#87 ASC NULLS FIRST, col_2#88 ASC NULLS FIRST], false, 0
      +- *(10) HashAggregate(keys=[col_1#87, col_2#88], functions=[max(measure_2#90), avg(cast(measure_2#90 as double))])
         +- *(10) HashAggregate(keys=[col_1#87, col_2#88], functions=[partial_max(measure_2#90), partial_avg(cast(measure_2#90 as double))])
            +- *(10) Project [col_1#87, col_2#88, measure_2#90]
               +- *(10) SortMergeJoin [col_1#87, col_2#88], [col_1#63, col_2#64], Inner
                  :- *(7) Sort [col_1#87 ASC NULLS FIRST, col_2#88 ASC NULLS FIRST], false, 0
                  :  +- Exchange hashpartitioning(col_1#87, col_2#88, 200), ENSURE_REQUIREMENTS, [id=#293]
                  :     +- *(6) Project [col_1#87, col_2#88, measure_2#90]
                  :        +- *(6) Scan ExistingRDD[col_1#87,col_2#88,measure_1#89,measure_2#90]
                  +- *(9) Sort [col_1#63 ASC NULLS FIRST, col_2#64 ASC NULLS FIRST], false, 0
                     +- *(9) HashAggregate(keys=[col_1#63, col_2#64], functions=[])
                        +- Exchange hashpartitioning(col_1#63, col_2#64, 200), ENSURE_REQUIREMENTS, [id=#299]
                           +- *(8) HashAggregate(keys=[col_1#63, col_2#64], functions=[])
                              +- *(8) Project [col_1#63, col_2#64]
                                 +- *(8) Scan ExistingRDD[col_1#63,col_2#64,measure_1#65,measure_2#66]
"""
But if, for instance, I change the above code so that the initial df is the result of a union and a join:
from pyspark.sql import types as T, functions as F, SparkSession
spark = SparkSession.builder.getOrCreate()
schema = T.StructType([
    T.StructField("col_1", T.IntegerType(), False),
    T.StructField("col_2", T.IntegerType(), False),
    T.StructField("measure_1", T.FloatType(), False),
    T.StructField("measure_2", T.FloatType(), False),
])
data = [
    {"col_1": 1, "col_2": 2, "measure_1": 0.5, "measure_2": 1.5},
    {"col_1": 2, "col_2": 3, "measure_1": 2.5, "measure_2": 3.5},
]
df = spark.createDataFrame(data, schema)
right_schema = T.StructType([
    T.StructField("col_1", T.IntegerType(), False)
])
right_data = [
    {"col_1": 1},
    {"col_1": 1},
    {"col_1": 2},
    {"col_1": 2},
]
right_df = spark.createDataFrame(right_data, right_schema)
df = df.unionByName(df)
df = df.join(right_df, on="col_1")
df.show()
"""
+-----+-----+---------+---------+
|col_1|col_2|measure_1|measure_2|
+-----+-----+---------+---------+
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
+-----+-----+---------+---------+
"""
df.explain()
"""
== Physical Plan ==
*(7) Project [col_1#299, col_2#300, measure_1#301, measure_2#302, col_2#354, measure_1#355, measure_2#356]
+- *(7) SortMergeJoin [col_1#299], [col_1#353], Inner
   :- *(3) Sort [col_1#299 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(col_1#299, 200), ENSURE_REQUIREMENTS, [id=#595]
   :     +- Union
   :        :- *(1) Scan ExistingRDD[col_1#299,col_2#300,measure_1#301,measure_2#302]
   :        +- *(2) Scan ExistingRDD[col_1#299,col_2#300,measure_1#301,measure_2#302]
   +- *(6) Sort [col_1#353 ASC NULLS FIRST], false, 0
      +- ReusedExchange [col_1#353, col_2#354, measure_1#355, measure_2#356], Exchange hashpartitioning(col_1#299, 200), ENSURE_REQUIREMENTS, [id=#595]
"""
group_cols = ["col_1", "col_2"]
measure_cols = ["measure_1", "measure_2"]
for col in measure_cols:
    stats = df.groupBy(group_cols).agg(
        F.max(col).alias("max_" + col),
        F.avg(col).alias("avg_" + col),
    )
    df = df.join(stats, group_cols)
df.show()
"""
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
"""
df.explain()
"""
== Physical Plan ==
*(31) Project [col_1#404, col_2#405, measure_1#406, measure_2#407, max_measure_1#465, avg_measure_1#467, max_measure_2#489, avg_measure_2#491]
+- *(31) SortMergeJoin [col_1#404, col_2#405], [col_1#496, col_2#497], Inner
   :- *(15) Project [col_1#404, col_2#405, measure_1#406, measure_2#407, max_measure_1#465, avg_measure_1#467]
   :  +- *(15) SortMergeJoin [col_1#404, col_2#405], [col_1#472, col_2#473], Inner
   :     :- *(7) Sort [col_1#404 ASC NULLS FIRST, col_2#405 ASC NULLS FIRST], false, 0
   :     :  +- Exchange hashpartitioning(col_1#404, col_2#405, 200), ENSURE_REQUIREMENTS, [id=#1508]
   :     :     +- *(6) Project [col_1#404, col_2#405, measure_1#406, measure_2#407]
   :     :        +- *(6) SortMergeJoin [col_1#404], [col_1#412], Inner
   :     :           :- *(3) Sort [col_1#404 ASC NULLS FIRST], false, 0
   :     :           :  +- Exchange hashpartitioning(col_1#404, 200), ENSURE_REQUIREMENTS, [id=#1494]
   :     :           :     +- Union
   :     :           :        :- *(1) Scan ExistingRDD[col_1#404,col_2#405,measure_1#406,measure_2#407]
   :     :           :        +- *(2) Scan ExistingRDD[col_1#404,col_2#405,measure_1#406,measure_2#407]
   :     :           +- *(5) Sort [col_1#412 ASC NULLS FIRST], false, 0
   :     :              +- Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
   :     :                 +- *(4) Scan ExistingRDD[col_1#412]
   :     +- *(14) Sort [col_1#472 ASC NULLS FIRST, col_2#473 ASC NULLS FIRST], false, 0
   :        +- Exchange hashpartitioning(col_1#472, col_2#473, 200), ENSURE_REQUIREMENTS, [id=#1639]
   :           +- *(13) HashAggregate(keys=[col_1#472, col_2#473], functions=[max(measure_1#474), avg(cast(measure_1#474 as double))])
   :              +- *(13) HashAggregate(keys=[col_1#472, col_2#473], functions=[partial_max(measure_1#474), partial_avg(cast(measure_1#474 as double))])
   :                 +- *(13) Project [col_1#472, col_2#473, measure_1#474]
   :                    +- *(13) SortMergeJoin [col_1#472], [col_1#412], Inner
   :                       :- *(10) Sort [col_1#472 ASC NULLS FIRST], false, 0
   :                       :  +- Exchange hashpartitioning(col_1#472, 200), ENSURE_REQUIREMENTS, [id=#1516]
   :                       :     +- Union
   :                       :        :- *(8) Project [col_1#472, col_2#473, measure_1#474]
   :                       :        :  +- *(8) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
   :                       :        +- *(9) Project [col_1#472, col_2#473, measure_1#474]
   :                       :           +- *(9) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
   :                       +- *(12) Sort [col_1#412 ASC NULLS FIRST], false, 0
   :                          +- ReusedExchange [col_1#412], Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
   +- *(30) Sort [col_1#496 ASC NULLS FIRST, col_2#497 ASC NULLS FIRST], false, 0
      +- *(30) HashAggregate(keys=[col_1#496, col_2#497], functions=[max(measure_2#499), avg(cast(measure_2#499 as double))])
         +- *(30) HashAggregate(keys=[col_1#496, col_2#497], functions=[partial_max(measure_2#499), partial_avg(cast(measure_2#499 as double))])
            +- *(30) Project [col_1#496, col_2#497, measure_2#499]
               +- *(30) SortMergeJoin [col_1#496, col_2#497], [col_1#472, col_2#473], Inner
                  :- *(22) Sort [col_1#496 ASC NULLS FIRST, col_2#497 ASC NULLS FIRST], false, 0
                  :  +- Exchange hashpartitioning(col_1#496, col_2#497, 200), ENSURE_REQUIREMENTS, [id=#1660]
                  :     +- *(21) Project [col_1#496, col_2#497, measure_2#499]
                  :        +- *(21) SortMergeJoin [col_1#496], [col_1#412], Inner
                  :           :- *(18) Sort [col_1#496 ASC NULLS FIRST], false, 0
                  :           :  +- Exchange hashpartitioning(col_1#496, 200), ENSURE_REQUIREMENTS, [id=#1544]
                  :           :     +- Union
                  :           :        :- *(16) Project [col_1#496, col_2#497, measure_2#499]
                  :           :        :  +- *(16) Scan ExistingRDD[col_1#496,col_2#497,measure_1#498,measure_2#499]
                  :           :        +- *(17) Project [col_1#496, col_2#497, measure_2#499]
                  :           :           +- *(17) Scan ExistingRDD[col_1#496,col_2#497,measure_1#498,measure_2#499]
                  :           +- *(20) Sort [col_1#412 ASC NULLS FIRST], false, 0
                  :              +- ReusedExchange [col_1#412], Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
                  +- *(29) Sort [col_1#472 ASC NULLS FIRST, col_2#473 ASC NULLS FIRST], false, 0
                     +- Exchange hashpartitioning(col_1#472, col_2#473, 200), ENSURE_REQUIREMENTS, [id=#1707]
                        +- *(28) HashAggregate(keys=[col_1#472, col_2#473], functions=[])
                           +- *(28) HashAggregate(keys=[col_1#472, col_2#473], functions=[])
                              +- *(28) Project [col_1#472, col_2#473]
                                 +- *(28) SortMergeJoin [col_1#472], [col_1#412], Inner
                                    :- *(25) Sort [col_1#472 ASC NULLS FIRST], false, 0
                                    :  +- Exchange hashpartitioning(col_1#472, 200), ENSURE_REQUIREMENTS, [id=#1566]
                                    :     +- Union
                                    :        :- *(23) Project [col_1#472, col_2#473]
                                    :        :  +- *(23) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
                                    :        +- *(24) Project [col_1#472, col_2#473]
                                    :           +- *(24) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
                                    +- *(27) Sort [col_1#412 ASC NULLS FIRST], false, 0
                                       +- ReusedExchange [col_1#412], Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
"""
You can see in the query plan that the union + join is derived several times, and this is reflected in my job's execution report, where I see stages with identical numbers of tasks run again and again.
How can I stop this re-derivation from happening?
The inner loop of your transform, where you repeatedly join and derive columns against a base DataFrame, would benefit from PySpark's .cache() method. Calling it marks the DataFrame for persistence: Spark materializes it on the first action and holds on to the result rather than re-computing it. This means you compute the initial union + join a single time, then re-use that DataFrame in every subsequent transformation. It is a one-line addition that removes a large amount of redundant work from your job.
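Applied to your second example, the change looks like the sketch below; the name base_df is mine, introduced so the cached DataFrame can still be referenced (and released) later:
df = df.unionByName(df)
base_df = df.join(right_df, on="col_1").cache()  # marked for caching; materialized on the first action

df = base_df
group_cols = ["col_1", "col_2"]
measure_cols = ["measure_1", "measure_2"]
for col in measure_cols:
    stats = df.groupBy(group_cols).agg(
        F.max(col).alias("max_" + col),
        F.avg(col).alias("avg_" + col),
    )
    # every groupBy/join from here on reads the cached base instead of re-deriving the union + join
    df = df.join(stats, group_cols)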
You can now see in the query plan that an InMemoryTableScan over an InMemoryRelation is used in place of the recurring shuffles, and your job's execution will reflect as much.
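If you want to confirm the cache is in effect, a quick check (a sketch; the exact output varies by Spark version):
base_df.count()              # any action materializes the cache
print(base_df.storageLevel)  # a non-default storage level confirms the DataFrame is persisted
df.explain()                 # downstream plans now show InMemoryTableScan nodes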
Note: .cache() doesn't truncate your query lineage at all; the full plan is retained, so Spark can still recompute partitions if they are ever evicted. It simply changes the manner in which your data is materialized and re-used.
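Finally, once the derived results have been consumed, it's good practice to release the cached data so the executors can reclaim the memory:
base_df.unpersist()  # frees the cached blocks; the lineage remains, so the data can still be recomputed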