I have two PySpark dataframes: df1 with an IntegerType column grp, and df2 with a collect_set column collected_grp.
I want to join the two such that, for each set in df2, all the matching rows of df1 end up in the same group.
df1 is as below:
+--------------------------------+---+
|ID |grp|
+--------------------------------+---+
|7d693086c5b8f74cbe881166cf3c2a29|2 |
|fcb907411aff4f44c599cf03d23327c0|2 |
|7933546917973caa8c2898c834446415|1 |
|3ef2e38d48a9af3e096ddd3bc3816afb|1 |
|7e18b452bb1e2845800a71d9431033b6|3 |
|9bc9d06e0efb16abde20c35ba36a2f1b|3 |
|7e18b452bb1e2845800a71d9431033b6|4 |
|ff351ada316cbb0f270f935adfd16ad4|4 |
|8919d5fd5b6fd118c1c6b691c65c9df9|6 |
.......
+--------------------------------+---+
And df2 as below:
+--------------------------------+-------------+
|ID |collected_grp|
+--------------------------------+-------------+
|fcb907411aff4f44c599cf03d23327c0|[2] |
|ff351ada316cbb0f270f935adfd16ad4|[16, 4] |
|9bc9d06e0efb16abde20c35ba36a2f1b|[16, 3] |
|7e18b452bb1e2845800a71d9431033b6|[16, 3, 4] |
|8919d5fd5b6fd118c1c6b691c65c9df9|[6, 7, 8] |
|484f25e9ab91af2c116cd788c91bdc82|[5] |
|8dc7dfb4466590375f1aaac7fc8cb987|[6, 8] |
|8240cf1e442a97aa91d1029270728bbb|[5] |
|9b93e3cfc5605e74ce2ce4c9450fd622|[7, 8] |
|41f007c0cc45c228e246f1cc91145878|[9, 13] |
|8f459a7cff281bad73f604166841849e|[9, 14] |
|99f70106443a6f3f5c69d99a49d22d01|[10] |
|f6da014449e6fa82c24d002b4a27b105|[9, 13, 14] |
|be73ca52536d13dfea295d4fcd273fde|[10] |
......
+--------------------------------+-------------+
I want to join df2 with df1 such that, for arrays like [16, 4] and [16, 3, 4], all the values of each grp end up in one group. For example, since grp 16 appears together with both 3 and 4, grps 3, 4, and 16 should collapse into a single group.
Any help is appreciated.
Below is the code for creating both dataframes:
import pyspark.sql.functions as F
from pyspark.sql.functions import explode, collect_set

data = [
['7933546917973caa8c2898c834446415', '3ef2e38d48a9af3e096ddd3bc3816afb', 1],
['7d693086c5b8f74cbe881166cf3c2a29', 'fcb907411aff4f44c599cf03d23327c0', 2],
['7e18b452bb1e2845800a71d9431033b6', '9bc9d06e0efb16abde20c35ba36a2f1b', 3],
['7e18b452bb1e2845800a71d9431033b6', 'ff351ada316cbb0f270f935adfd16ad4', 4],
['8240cf1e442a97aa91d1029270728bbb', '484f25e9ab91af2c116cd788c91bdc82', 5],
['8919d5fd5b6fd118c1c6b691c65c9df9', '8dc7dfb4466590375f1aaac7fc8cb987', 6],
['8919d5fd5b6fd118c1c6b691c65c9df9', '9b93e3cfc5605e74ce2ce4c9450fd622', 7],
['8dc7dfb4466590375f1aaac7fc8cb987', '9b93e3cfc5605e74ce2ce4c9450fd622', 8],
['8f459a7cff281bad73f604166841849e', '41f007c0cc45c228e246f1cc91145878', 9],
['99f70106443a6f3f5c69d99a49d22d01', 'be73ca52536d13dfea295d4fcd273fde', 10],
['a9781767ca4fe8fb1282ee003d2c06ac', 'cb6feb2f38731fc7832545cbe2ac881b', 11],
['f4901968c29e928fc7364411b03336d4', '6fa82a51f17f0bf258fe06befc661216', 12],
['f6da014449e6fa82c24d002b4a27b105', '41f007c0cc45c228e246f1cc91145878', 13],
['f6da014449e6fa82c24d002b4a27b105', '8f459a7cff281bad73f604166841849e', 14],
['f93c0028bb26bc9b99fca1db300c2ac1', 'ccce888c5813025e95434d7ceedf1db3', 15],
['ff351ada316cbb0f270f935adfd16ad4', '9bc9d06e0efb16abde20c35ba36a2f1b', 16],
['ffe20a2c61638bb10bf943c42b4d794f', '985e237162ccfc04874664648893c241', 17],
]
df = spark.createDataFrame(data, schema=['id1', 'id2', 'grp'])

# Self-join: attach every row whose id2 equals this row's id1, then gather
# the (up to) four IDs involved into one de-duplicated array per row.
df2 = df.alias('df1') \
    .join(df.alias('df2'), F.col('df1.id1') == F.col('df2.id2'), 'left') \
    .select(F.array_distinct(F.array(F.col('df1.id1'), F.col('df1.id2'), F.col('df2.id1'), F.col('df2.id2'))).alias('ID'), F.col('df1.grp'))

# Explode to one (ID, grp) row per array element; dropna removes the
# nulls introduced by unmatched left-join rows.
df3 = df2.select(explode('ID').alias('ID'), 'grp').dropna()
df3.groupBy('ID').agg(collect_set('grp').alias('collected_grp')).show(40, truncate=False)
My expected output is:
+------------------------------------------------------------------------------------------------------+
|ID |
+------------------------------------------------------------------------------------------------------+
|[7d693086c5b8f74cbe881166cf3c2a29, fcb907411aff4f44c599cf03d23327c0] |
|[7933546917973caa8c2898c834446415, 3ef2e38d48a9af3e096ddd3bc3816afb] |
|[8240cf1e442a97aa91d1029270728bbb, 484f25e9ab91af2c116cd788c91bdc82] |
|[8dc7dfb4466590375f1aaac7fc8cb987, 9b93e3cfc5605e74ce2ce4c9450fd622, 8919d5fd5b6fd118c1c6b691c65c9df9]|
|[8f459a7cff281bad73f604166841849e, 41f007c0cc45c228e246f1cc91145878, f6da014449e6fa82c24d002b4a27b105]|
|[99f70106443a6f3f5c69d99a49d22d01, be73ca52536d13dfea295d4fcd273fde] |
|[a9781767ca4fe8fb1282ee003d2c06ac, cb6feb2f38731fc7832545cbe2ac881b] |
|[f4901968c29e928fc7364411b03336d4, 6fa82a51f17f0bf258fe06befc661216] |
|[ffe20a2c61638bb10bf943c42b4d794f, 985e237162ccfc04874664648893c241] |
|[ff351ada316cbb0f270f935adfd16ad4, 9bc9d06e0efb16abde20c35ba36a2f1b, 7e18b452bb1e2845800a71d9431033b6]|
|[f93c0028bb26bc9b99fca1db300c2ac1, ccce888c5813025e95434d7ceedf1db3] |
+------------------------------------------------------------------------------------------------------+
You can try using the networkx package along with pandas to get to the result. For the input data above, I ran logic that first creates a graph and then finds the connected nodes, which become the groups:
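A minimal sketch of that approach, assuming the df built from the data above and an active spark session (pdf, G, components, and result are illustrative names, not from the original post):

import networkx as nx

# Pull the edge list down to pandas; fine at this scale, but it runs on the driver.
pdf = df.select('id1', 'id2').toPandas()

# Build an undirected graph where every row is an edge between two IDs.
G = nx.from_pandas_edgelist(pdf, source='id1', target='id2')

# Each connected component of the graph is one merged group of IDs.
components = [sorted(c) for c in nx.connected_components(G)]

# Bring the groups back into Spark, one array of IDs per row.
result = spark.createDataFrame([(c,) for c in components], schema=['ID'])
result.show(truncate=False)

Since the component search happens on the driver, this only works comfortably when the edge list fits in memory; for larger graphs a distributed approach such as GraphFrames' connectedComponents would be the usual alternative.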
This gives one row per connected component, matching the expected output above (row and element ordering may vary).