Create column using Spark pandas_udf, with dynamic number of input columns


I have this df:

df = spark.createDataFrame(
    [('row_a', 5.0, 0.0, 11.0),
     ('row_b', 3394.0, 0.0, 4543.0),
     ('row_c', 136111.0, 0.0, 219255.0),
     ('row_d', 0.0, 0.0, 0.0),
     ('row_e', 0.0, 0.0, 0.0),
     ('row_f', 42.0, 0.0, 54.0)],
    ['value', 'col_a', 'col_b', 'col_c']
)

I would like to use .quantile(0.25, axis=1) from pandas, which would add one column:

import pandas as pd
pdf = df.toPandas()
pdf['25%'] = pdf.quantile(0.25, axis=1)
print(pdf)
#    value     col_a  col_b     col_c      25%
# 0  row_a       5.0    0.0      11.0      2.5
# 1  row_b    3394.0    0.0    4543.0   1697.0
# 2  row_c  136111.0    0.0  219255.0  68055.5
# 3  row_d       0.0    0.0       0.0      0.0
# 4  row_e       0.0    0.0       0.0      0.0
# 5  row_f      42.0    0.0      54.0     21.0

Performance is important to me, so I assume pandas_udf from pyspark.sql.functions could do it in a more optimized way. But I struggle to write a performant and reusable function. This is my best attempt:

from pyspark.sql import functions as F
import pandas as pd
@F.pandas_udf('double')
def quartile1_on_axis1(a: pd.Series, b: pd.Series, c: pd.Series) -> pd.Series:
    pdf = pd.DataFrame({'a':a, 'b':b, 'c':c})
    return pdf.quantile(0.25, axis=1)

df = df.withColumn('25%', quartile1_on_axis1('col_a', 'col_b', 'col_c'))

  1. I don't like that I need a separate argument for every column, and that inside the function I have to address each argument individually to build a DataFrame. All of those columns serve the same purpose, so IMHO there should be a way to address them all together, something like in this pseudocode:

    def quartile1_on_axis1(*cols) -> pd.Series:
        pdf = pd.DataFrame(cols)
    

    This way I could use this function for any number of columns.

  2. Is it necessary to create a pd.DataFrame inside the UDF? To me this seems the same as without the UDF (Spark df -> pandas df -> Spark df), as shown above, and without the UDF it's even shorter. Should I really try to make it work with pandas_udf, performance-wise? I think pandas_udf was designed specifically for this kind of purpose...

There are 4 answers below.

Accepted answer:

You can pass a single struct column instead of using multiple columns like this:

@F.pandas_udf('double')
def quartile1_on_axis1(s: pd.DataFrame) -> pd.Series:
    return s.quantile(0.25, axis=1)


cols = ['col_a', 'col_b', 'col_c']

df = df.withColumn('25%', quartile1_on_axis1(F.struct(*cols)))
df.show()

# +-----+--------+-----+--------+-------+
# |value|   col_a|col_b|   col_c|    25%|
# +-----+--------+-----+--------+-------+
# |row_a|     5.0|  0.0|    11.0|    2.5|
# |row_b|  3394.0|  0.0|  4543.0| 1697.0|
# |row_c|136111.0|  0.0|219255.0|68055.5|
# |row_d|     0.0|  0.0|     0.0|    0.0|
# |row_e|     0.0|  0.0|     0.0|    0.0|
# |row_f|    42.0|  0.0|    54.0|   21.0|
# +-----+--------+-----+--------+-------+

From the pyspark.sql.functions.pandas_udf documentation:

"Note that the type hint should use pandas.Series in all cases but there is one variant that pandas.DataFrame should be used for its input or output type hint instead when the input or output column is of pyspark.sql.types.StructType."
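
For what it's worth, here is a minimal sketch (my addition, not part of the original answer) of how the same idea could be wrapped in a small factory so that both the quantile level and the column list stay dynamic:

from pyspark.sql import functions as F
import pandas as pd

def row_quantile_udf(q: float):
    # Return a pandas_udf that computes the q-th quantile across the
    # fields of a struct column, row by row.
    @F.pandas_udf('double')
    def _quantile(s: pd.DataFrame) -> pd.Series:
        return s.quantile(q, axis=1)
    return _quantile

cols = ['col_a', 'col_b', 'col_c']
df = df.withColumn('25%', row_quantile_udf(0.25)(F.struct(*cols)))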

Answer:

The following seems to do what's required, but instead of pandas_udf it uses a regular udf. It would be great if I could employ pandas_udf in a similar way.

from pyspark.sql import functions as F
import numpy as np

@F.udf('double')
def lower_quart(*cols):
    return float(np.quantile(cols, 0.25))
df = df.withColumn('25%', lower_quart('col_a', 'col_b', 'col_c'))

df.show()
#+-----+--------+-----+--------+-------+
#|value|   col_a|col_b|   col_c|    25%|
#+-----+--------+-----+--------+-------+
#|row_a|     5.0|  0.0|    11.0|    2.5|
#|row_b|  3394.0|  0.0|  4543.0| 1697.0|
#|row_c|136111.0|  0.0|219255.0|68055.5|
#|row_d|     0.0|  0.0|     0.0|    0.0|
#|row_e|     0.0|  0.0|     0.0|    0.0|
#|row_f|    42.0|  0.0|    54.0|   21.0|
#+-----+--------+-----+--------+-------+
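
Since the regular udf accepts *cols, the column list can also be built dynamically instead of being spelled out, e.g. (a small usage note, not part of the original answer):

df = df.withColumn('25%', lower_quart(*df.columns[1:]))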

Answer:

The udf approach will get you the result you need, and is definitely the most straightforward. However, if performance really is a top priority, you can create your own native Spark implementation of the quantile. The basics can be coded quite easily; if you want to use any of the other pandas parameters, you'll need to tweak it yourself.

Note: this is taken from the pandas API docs for interpolation='linear'. If you intend to use it, please test the performance and verify the results yourself on large datasets.

import math
from pyspark.sql import functions as f

def quantile(q, cols):
    if q < 0 or q > 1:
        raise ValueError("Parameter q should be 0 <= q <= 1")

    if not cols:
        raise ValueError("List of columns should be provided")

    idx = (len(cols) - 1) * q
    i = math.floor(idx)
    j = math.ceil(idx)
    fraction = idx - i

    arr = f.array_sort(f.array(*cols))

    return arr.getItem(i) + (arr.getItem(j) - arr.getItem(i)) * fraction


columns = ['col_a', 'col_b', 'col_c']

df.withColumn('0.25%', quantile(0.25, columns)).show()

+-----+--------+-----+--------+-------+
|value|   col_a|col_b|   col_c|  0.25%|
+-----+--------+-----+--------+-------+
|row_a|     5.0|  0.0|    11.0|    2.5|
|row_b|  3394.0|  0.0|  4543.0| 1697.0|
|row_c|136111.0|  0.0|219255.0|68055.5|
|row_d|     0.0|  0.0|     0.0|    0.0|
|row_e|     0.0|  0.0|     0.0|    0.0|
|row_f|    42.0|  0.0|    54.0|   21.0|
+-----+--------+-----+--------+-------+
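
As a quick sanity check of the interpolation for row_a: the sorted values are [0.0, 5.0, 11.0], so idx = (3 - 1) * 0.25 = 0.5, i = 0, j = 1, fraction = 0.5, and the result is 0.0 + (5.0 - 0.0) * 0.5 = 2.5, which matches the pandas output above.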

As a side note, there is also the pandas API on Spark; however, axis=1 is not (yet) implemented for quantile. Potentially this will be added in the future.

df.to_pandas_on_spark().quantile(0.25, axis=1)

NotImplementedError: axis should be either 0 or "index" currently.

Answer:

I would use GroupedData with applyInPandas. Because this requires you to pass the output schema, first add a column of the required data type to the df and take its schema, then pass that schema to applyInPandas. Code below:

from pyspark.sql.functions import lit

# Generate the new schema by adding the new column
sch = df.withColumn('quantile25', lit(110.5)).schema

# Function applied to each group as a pandas DataFrame
def quartile1_on_axis1(pdf):
    pdf = pdf.assign(quantile25=pdf.quantile(0.25, axis=1))
    return pdf

# Apply the function
df.groupby('value').applyInPandas(quartile1_on_axis1, schema=sch).show()


#outcome
+-----+--------+-----+--------+----------+
|value|   col_a|col_b|   col_c|quantile25|
+-----+--------+-----+--------+----------+
|row_a|     5.0|  0.0|    11.0|       2.5|
|row_b|  3394.0|  0.0|  4543.0|    1697.0|
|row_c|136111.0|  0.0|219255.0|   68055.5|
|row_d|     0.0|  0.0|     0.0|       0.0|
|row_e|     0.0|  0.0|     0.0|       0.0|
|row_f|    42.0|  0.0|    54.0|      21.0|
+-----+--------+-----+--------+----------+


You could also use numpy in a udf to get this done. If you do not want to list all the columns, slice df.columns by index:

import numpy as np
from pyspark.sql.functions import array, udf
from pyspark.sql.types import FloatType

quartile1_on_axis1 = udf(lambda x: float(np.quantile(x, 0.25)), FloatType())
df.withColumn("0.25%", quartile1_on_axis1(array(df.columns[1:]))).show(truncate=False)

+-----+--------+-----+--------+-------+
|value|col_a   |col_b|col_c   |0.25%  |
+-----+--------+-----+--------+-------+
|row_a|5.0     |0.0  |11.0    |2.5    |
|row_b|3394.0  |0.0  |4543.0  |1697.0 |
|row_c|136111.0|0.0  |219255.0|68055.5|
|row_d|0.0     |0.0  |0.0     |0.0    |
|row_e|0.0     |0.0  |0.0     |0.0    |
|row_f|42.0    |0.0  |54.0    |21.0   |
+-----+--------+-----+--------+-------+
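
One small caveat (my note, not from the answer): np.quantile returns a 64-bit float, so declaring the udf's return type as DoubleType() rather than FloatType() would avoid downcasting the result:

from pyspark.sql.types import DoubleType

quartile1_on_axis1 = udf(lambda x: float(np.quantile(x, 0.25)), DoubleType())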