Rewrite UDF as pandas UDF in PySpark

I have a dataframe:

import pyspark.sql.functions as F

sdf1 = spark.createDataFrame(
    [
        (2022, 1, ["apple", "edible"]),
        (2022, 1, ["edible", "fruit"]),
        (2022, 1, ["orange", "sweet"]),
        (2022, 4, ["flowering ", "plant"]),
        (2022, 3, ["green", "kiwi"]),
        (2022, 3, ["kiwi", "fruit"]),
        (2022, 3, ["fruit", "popular"]),
        (2022, 3, ["yellow", "lemon"]),
    ],
    [
        "year",
        "id",
        "bigram",
    ],
)
sdf1.show(truncate=False)

    +----+---+-------------------+
    |year|id |bigram             |
    +----+---+-------------------+
    |2022|1  |[apple, edible]    |
    |2022|1  |[edible, fruit]    |
    |2022|1  |[orange, sweet]    |
    |2022|4  |[flowering , plant]|
    |2022|3  |[green, kiwi]      |
    |2022|3  |[kiwi, fruit]      |
    |2022|3  |[fruit, popular]   |
    |2022|3  |[yellow, lemon]    |
    +----+---+-------------------+

And I wrote a function that chains bigrams together when the last word of one matches the first word of another (e.g. [apple, edible] and [edible, fruit] yield [apple, edible, fruit]). I apply this function to the column separately.

from networkx import DiGraph, dfs_labeled_edges
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Grouping
sdf = (
    sdf1.groupby("year", "id")
    .agg(F.collect_set("bigram").alias("collect_bigramm"))
    .withColumn("size", F.size("collect_bigramm"))
)

data_collect = sdf.collect()


@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
    graph = DiGraph()

    # data_collect was collected on the driver; it is only used here for the "size" check
    for row in data_collect:
        if row["size"] > 1:
            # compare each bigram with every later one and add edges where the
            # last word of one matches the first word of the other
            for i, lst1 in enumerate(lst):
                while i < len(lst) - 1:
                    lst2 = lst[i + 1]
                    if lst1[0] == lst2[1]:
                        graph.add_edge(lst2[0], lst2[1])
                        graph.add_edge(lst1[0], lst1[1])
                    elif lst1[1] == lst2[0]:
                        graph.add_edge(lst1[0], lst1[1])
                        graph.add_edge(lst2[0], lst2[1])
                    i = i + 1

            # walk the graph depth-first; forward/reverse events delimit each chain of words
            gen = dfs_labeled_edges(graph)
            lst_tmp = []
            lst_res = []
            f = 0
            for g in list(gen):
                if (g[2] == "forward") and (g[0] != g[1]):
                    f = 1
                    lst_tmp.append(g[0])
                    lst_tmp.append(g[1])

                if g[2] == "nontree":
                    continue
                if g[2] == "reverse":
                    if f == 1:
                        lst_res.append(lst_tmp.copy())
                    f = 0
                    if g[0] in lst_tmp:
                        lst_tmp.remove(g[0])
                    if g[1] in lst_tmp:
                        lst_tmp.remove(g[1])

            # deduplicate the first chain found, preserving word order
            if lst_res != []:
                lst_res = [
                    ii for n, ii in enumerate(lst_res[0]) if ii not in lst_res[0][:n]
                ]
            if lst_res == []:
                lst_res = None
            return lst_res


sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
sdf_new.show(truncate=False)

Output:

+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm                                                  |size|new_col                      |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4  |[[flowering , plant]]                                            |1   |null                         |
|2022|1  |[[edible, fruit], [orange, sweet], [apple, edible]]              |3   |[apple, edible, fruit]       |
|2022|3  |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4   |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+

But now I want to use a pandas UDF. I would like to do the groupby first and build the collect_bigramm column inside the function, thus keeping all the columns of the dataframe while also adding a new one: the lst_res array from the function.


from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import (
    StructType, StructField, IntegerType, ArrayType, StringType
)

schema2 = StructType(
    [
        StructField("year", IntegerType(), True),
        StructField("id", IntegerType(), True),
        StructField("bigram", ArrayType(StringType(), True), True),
        StructField("new_col", ArrayType(StringType(), True), True),
        StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
    ]
)


@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):

    graph = DiGraph()
    for index, row in df.iterrows():
        # Instead of the variable lst, I need to use the sdf['collect_bigramm'] column here
        ...

    return df


sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)

Best answer:

  1. You don't want to run groupBy twice (once for sdf1 and once for the pandas_udf); that would simply kill the idea of pandas_udf, which is to group a list of records, vectorize them, and send them to a worker. You'd want to do something like this instead: sdf1.groupby("year", "id").applyInPandas(myfunc, schema2).

  2. Your UDF is now a "pandas UDF", which is literally just a Python function that takes one pandas DataFrame and returns another pandas DataFrame. In that sense, you can even run the function without Spark. The trick is just how to shape the dataframe it is fed. Check the running code below; I kept most of your networkx code and only fixed the input and output a little.

from networkx import DiGraph, dfs_labeled_edges


def myfunc(pdf):
    # Each incoming pdf holds the raw rows of one (year, id) group, so the
    # collect_set-style aggregation is re-created here in pandas
    pdf = (pdf
        .groupby(['year', 'id'])['bigram']
        .agg(list=list, len=len)  # to mimic collect_set you could dedupe here
                                  # (bigrams are lists, so convert them to tuples first)
        .reset_index()
        .rename(columns={
            'list': 'collect_bigramm',
            'len': 'size',
        })
    )

    graph = DiGraph()
    if pdf['size'][0] > 1:
        lst = pdf['collect_bigramm'][0]
        for i, lst1 in enumerate(lst):
            ...  # same pairwise edge-building as in the original UDF
        ...  # same dfs_labeled_edges post-processing as in the original UDF, producing lst_res
        if lst_res == []:
            lst_res = None
        pdf['new_col'] = [lst_res]
    else:
        pdf['new_col'] = None
    return pdf
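
For completeness, the full wiring is then a single applyInPandas call. The schema here is a sketch of my own rather than part of the original answer: applyInPandas matches the returned pandas columns to the schema by name, so it lists exactly the columns myfunc returns, and it uses LongType because pandas yields int64 for the numeric columns.

from pyspark.sql.types import (
    StructType, StructField, LongType, ArrayType, StringType
)

# Exactly the columns myfunc returns: year, id, collect_bigramm, size, new_col
schema2 = StructType(
    [
        StructField("year", LongType(), True),
        StructField("id", LongType(), True),
        StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
        StructField("size", LongType(), True),
        StructField("new_col", ArrayType(StringType(), True), True),
    ]
)

# Group the raw rows; Spark hands each (year, id) group to myfunc as a
# pandas DataFrame and concatenates whatever myfunc returns
sdf_new = sdf1.groupby("year", "id").applyInPandas(myfunc, schema2)
sdf_new.show(truncate=False)

And because myfunc is ordinary pandas code, it can be sanity-checked without a Spark session (assuming the elided networkx part has been filled in from the original UDF), e.g.:

import pandas as pd

pdf = pd.DataFrame(
    {
        "year": [2022, 2022],
        "id": [1, 1],
        "bigram": [["apple", "edible"], ["edible", "fruit"]],
    }
)
print(myfunc(pdf))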