Most efficient way to join two PySpark dataframes, pivot and fill NULL based on condition

28 Views Asked by At

I have two large PySpark dataframes in long format. The final table should be in wide format. I cannot figure out what is the best way to do it.

Thank you for the support.

from pyspark.sql import Row
from pyspark.sql import functions as F

data1 = [Row(Component='b1',Subcomponent='a11',Class=1),
         Row(Component='b1',Subcomponent='a12',Class=1),
         Row(Component='c1',Subcomponent='b1',Class=2),
         Row(Component='b2',Subcomponent='a21',Class=1),
         Row(Component='b2',Subcomponent='a22',Class=1),
         Row(Component='c2',Subcomponent='b2',Class=2)]
df1 = spark.createDataFrame(data1)
df1.show()

data2 = [Row(Part='a11',Parameter='X_01',Value=1101),
         Row(Part='a11',Parameter='X_02',Value=1102),
         Row(Part='a12',Parameter='X_01',Value=1201),
         Row(Part='a12',Parameter='X_02',Value=1202),
         Row(Part='b1',Parameter='Y',Value=1),
         Row(Part='c1',Parameter='Z',Value=10),
         Row(Part='a21',Parameter='X_01',Value=2101),
         Row(Part='a21',Parameter='X_02',Value=2102),
         Row(Part='a22',Parameter='X_01',Value=2201),
         Row(Part='a22',Parameter='X_02',Value=2202),
         Row(Part='b2',Parameter='Y',Value=2),
         Row(Part='c2',Parameter='Z',Value=20)]
df2 = spark.createDataFrame(data2)
df2.show()
+---------+------------+
|Component|Subcomponent|
+---------+------------+
|       b1|         a11|
|       b1|         a12|
|       c1|          b1|
|       b2|         a21|
|       b2|         a22|
|       c2|          b2|
+---------+------------+

+----+---------+-----+
|Part|Parameter|Value|
+----+---------+-----+
| a11|     X_01| 1101|
| a11|     X_02| 1102|
| a12|     X_01| 1201|
| a12|     X_02| 1202|
|  b1|        Y|    1|
|  c1|        Z|   10|
| a21|     X_01| 2101|
| a21|     X_02| 2102|
| a22|     X_01| 2201|
| a22|     X_02| 2202|
|  b2|        Y|    2|
|  c2|        Z|   20|
+----+---------+-----+

The result table should look like this:

+---------+--------------------------------------+---------+---------+---------+
|Component|                             X_01_mean|X_02_mean|        Y|        Z|
+---------+--------------------------------------+---------+---------+---------+
|       b1|                      1151 (1101+1201)|     1152|        1|     Null|
|       b2|                                  2151|     2152|        2|     Null|
|       c1|1151 (fill out since b1 is part of c1)|     1152|        1|       10|
|       c2|                                  2151|     2152|        2|       20|
+---------+--------------------------------------+---------+---------+---------+

So far my code looks like this:

# join both tables
df12 = (df1.alias('a')
  .join(df2.alias('b'),
    on=F.col('a.Subcomponent') == F.col('b.Part'),
    how='left')
  .drop(F.col('b.Component'))
)

# pivot and aggregate class == 1
df12_class1 = (df12
  .where(F.col('Class')==1)
  .groupby('Component')
  .pivot('Parameter')
  .agg(F.mean('Value').alias('mean'))
)

# pivot class == 2
df12_class2 = (df12
  .where(F.col('Class')==2)
  .groupby('Component')
  .pivot('Parameter')
  .agg(F.first('Value'))
)

# join class1
df12_join = (df12.alias('a')
  .select(['Component','Subcomponent'])
  .dropDuplicates(['Component'])
  .join(df12_class1.alias('b'),
    on=F.col('a.Component') == F.col('b.Component'),
    how='left')
  .drop(F.col('b.Component'))
)

# join class2
df12_join = (df12_join.alias('a')
  .join(df12_class2.alias('b'),
    on=F.col('a.Component') == F.col('b.Component'),
    how='left')
  .drop(F.col('b.Component'))
)
df12_join.show()

So far the table looks like this.

+---------+------------+---------+---------+----+
|Component|Subcomponent|X_01_mean|X_02_mean|   Y|
+---------+------------+---------+---------+----+
|       b1|         a11|     1151|     1152|NULL|
|       b2|         a21|     2151|     2152|NULL|
|       c1|          b1|     NULL|     NULL|   1|
|       c2|          b2|     NULL|     NULL|NULL|
+---------+------------+---------+---------+----+

Still cannot figure out how to fill the NULL and how to add the Parameter='Z'. Is there a better way to aggregate and pivot as I have done it so far?

0

There are 0 best solutions below