Looping through an array-of-floats column in a PySpark DataFrame to find which values pass a condition


I am using PySpark 2.4.0.

I have a list of floats:

[0.6067762380157111,
 0.4595708424660512,
 0.20093402090021173,
 0.5736288504883545,
 0.46593043507957116,
 0.5734057882715504,
 0.6940067723754003,
 0.30921836906829625,
 0.768595041322314...]

These values are the result of phonetically transcribing a list of words.

I have multiple PySpark DataFrames like this one:

+------------------------+--------------------------+---------------------------------------------+
|com                     |split                     |phoned                                       |
+------------------------+--------------------------+---------------------------------------------+
|sans option             |[sans, option]            |[0.6832268970698724, 0.6248699945979845]     |
|                        |[]                        |[]                                           |
|fermer l hiver          |[fermer, l, hiver]        |[0.3154196179245309, 0.5, 0.3828829842720629]|
+------------------------+--------------------------+---------------------------------------------+

The idea would be to find, for every value in "phoned" across all of my PySpark DataFrames, whether any value in the array is "close" to a value in the list (i.e. within a given threshold, which I can change).

So, for every DataFrame: loop through all the values in the "phoned" column, loop through the given list, take the difference between each value and every element of the list, and whenever a difference is below the threshold, put the corresponding word in another column. If possible, I would like to get all the values that fall within the given threshold, not just the first one.
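In plain Python, the per-row logic I am after would look roughly like this (close_words, words and word_phones are placeholder names I made up for illustration):

def close_words(phoned_row, words, word_phones, threshold=0.01):
    # For each transcription value in a row, keep every reference word
    # whose own transcription is within `threshold` of that value
    result = []
    for value in phoned_row:
        for word, phone in zip(words, word_phones):
            if abs(value - phone) <= threshold:
                result.append(word)
    return result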

For example, if my list of words were this one:

["sans",
 "opssion",
 "test",
 "ferme",...]

Its phonetic transcription would be:

[0.6832268970698724,
0.625416705239052,
0.7390120210368145,
0.3154165838771569,...]

And the result I would like, for a threshold of 0.01, would be:

+------------------------+--------------------------+---------------------------------------------+---------------+
|com                     |split                     |phoned                                       |result         |
+------------------------+--------------------------+---------------------------------------------+---------------+
|sans option             |[sans, option]            |[0.6832268970698724, 0.6248699945979845]     |[sans, opssion]|
|                        |[]                        |[]                                           |[]             |
|fermer l hiver          |[fermer, l, hiver]        |[0.3154196179245309, 0.5, 0.3828829842720629]|[ferme]        |
+------------------------+--------------------------+---------------------------------------------+---------------+

I've tried to work my way around this with some UDFs, but I didn't find a solution that gives this kind of result. I have a dozen DataFrames, some with several "com"-like columns and ~1 million records in some of them, so I can't handle this in Pandas.

Thanks!

Accepted answer by Arud Seka Berne S:

Your DataFrame (df_1):

+--------------+------------------+---------------------------------------------+
|com           |split             |phoned                                       |
+--------------+------------------+---------------------------------------------+
|sans option   |[sans, option]    |[0.6832268970698724, 0.6248699945979845]     |
|              |[]                |[]                                           |
|fermer l hiver|[fermer, l, hiver]|[0.3154196179245309, 0.5, 0.3828829842720629]|
+--------------+------------------+---------------------------------------------+

Your constants:

threshold = 0.01

match = [0.6067762380157111,0.4595708424660512,
         0.20093402090021173,0.5736288504883545,
         0.46593043507957116,0.5734057882715504,
         0.6940067723754003,0.30921836906829625,0.768595041322314]

Create the UDF:

def between_threshold(element, threshold):
    # Null values (padding produced by posexplode_outer) never match
    if element is None:
        return False
    # True if any reference value in `match` lies within `threshold` of this element
    return any(abs(element - match_element) <= threshold for match_element in match)
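
A quick sanity check on the constants above (my addition, not part of the original answer):

# 0.3154196... is within 0.01 of 0.30921836906829625, so this is True
print(between_threshold(0.3154196179245309, threshold))

# 0.6832268... is more than 0.01 away from every value in `match`, so this is False
print(between_threshold(0.6832268970698724, threshold))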

Import the necessary packages:

from pyspark.sql.functions import lit, col, udf, row_number, posexplode_outer, collect_list
from pyspark.sql.types import BooleanType
from pyspark.sql.window import Window
1. Create a unique row identifier

# A window ordered by a constant assigns row numbers in an arbitrary order
row_window_spec = Window.orderBy(lit(1))

df_2 = df_1.withColumn("row_num", row_number().over(row_window_spec))
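
A side note (my addition, not part of the original answer): a window ordered by a constant makes Spark move every row into a single partition to number them, which can hurt at ~1 million records. If the identifier only needs to be unique, not consecutive, monotonically_increasing_id avoids that:

from pyspark.sql.functions import monotonically_increasing_id

# Unique (but non-consecutive) ids, assigned per partition without a global sort
df_2 = df_1.withColumn("row_num", monotonically_increasing_id())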
2. Find whether each element in "phoned" is within the threshold, using the UDF

df_3 = df_2.select("row_num", posexplode_outer("split").alias("index", "split"), "phoned")

# Keep only the phoned value that lines up with this row's exploded word
df_4 = df_3.withColumn("phoned", col("phoned")[col("index")])

# Register the Python function as a UDF, then apply it column-wise
between_threshold_udf = udf(between_threshold, BooleanType())
df_5 = df_4.withColumn("between_threshold", between_threshold_udf(col("phoned"), lit(threshold)))
3. Filter and join with the original DataFrame
df_6 = df_5.filter(col("between_threshold") == True) \
        .groupBy("row_num") \
        .agg(collect_list("split").alias("result"))

df_2.join(df_6, "row_num", "left").drop("row_num").show(truncate=False)

Output

+--------------+------------------+---------------------------------------------+--------+
|com           |split             |phoned                                       |result  |
+--------------+------------------+---------------------------------------------+--------+
|sans option   |[sans, option]    |[0.6832268970698724, 0.6248699945979845]     |null    |
|              |[]                |[]                                           |null    |
|fermer l hiver|[fermer, l, hiver]|[0.3154196179245309, 0.5, 0.3828829842720629]|[fermer]|
+--------------+------------------+---------------------------------------------+--------+

Note: in this dataset, only "fermer" matches with the threshold 0.01.
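
If you would rather have an empty array than null in "result", as in the output shown in the question, one option (my addition, assuming the df_2 and df_6 names above) is to coalesce the column with an empty array before showing:

from pyspark.sql.functions import array, coalesce

# Replace null results (rows with no match) with an empty array
df_2.join(df_6, "row_num", "left") \
    .withColumn("result", coalesce(col("result"), array())) \
    .drop("row_num") \
    .show(truncate=False)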