Here is my schema:
root
 |-- embedding_init: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- embeddings: array (nullable = false)
 |    |-- element: array (containsNull = false)
 |    |    |-- element: double (containsNull = true)
I'm looking to create a UDF that calculates the cosine similarity between embedding_init
and each of the arrays within embeddings.
Here is my attempt:
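import numpy as np
from pyspark.sql import types as T
from pyspark.sql.functions import pandas_udf, PandasUDFType
from sklearn.metrics.pairwise import cosine_similarity

@pandas_udf(T.ArrayType(T.DoubleType()), PandasUDFType.SCALAR)
def cosine_sim(embedding_init, embeddings):
    embedding_init = np.array([embedding_init])
    embeddings = np.array(embeddings)
    sims = cosine_similarity(embedding_init, embeddings)[0]
    return sims

df.withColumn("cosine_similarity", cosine_sim(df.embedding_init, df.embeddings))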
When I do this, I constantly get the following error:
An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "<ipython-input-90-9bd905bd9575>", line 16, in cosine_sim
  File "/artifacts/virtualenv/starscream_default/lib/python3.7/site-packages/sklearn/metrics/pairwise.py", line 1246, in cosine_similarity
    X, Y = check_pairwise_arrays(X, Y)
  File "/artifacts/virtualenv/starscream_default/lib/python3.7/site-packages/sklearn/metrics/pairwise.py", line 153, in check_pairwise_arrays
    estimator=estimator,
  File "/artifacts/virtualenv/starscream_default/lib/python3.7/site-packages/sklearn/utils/validation.py", line 742, in check_array
    ) from complex_warning
ValueError: setting an array element with a sequence.
More details about the data:
- There are 9 arrays within embeddings, and each of them has size 512.
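From the documentation:
It looks like the shapes of the arrays have to be compatible: cosine_similarity expects embedding_init and embeddings as 2-D numeric arrays with the same number of columns, so every inner array must have the same length.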
For example, your function works when it is called with arguments whose shapes match.
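A minimal sketch of such a call, assuming the body of the UDF is run as a plain Python function outside Spark. The name cosine_sim_plain and the sample vectors are made up for illustration:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def cosine_sim_plain(embedding_init, embeddings):
    # Same body as the UDF in the question, without the Spark decorator.
    embedding_init = np.array([embedding_init])  # shape (1, n_features)
    embeddings = np.array(embeddings)            # shape (n_embeddings, n_features)
    return cosine_similarity(embedding_init, embeddings)[0]

# Illustrative values: one 3-element vector against two 3-element vectors.
init = [1.0, 0.0, 0.0]
candidates = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

sims = cosine_sim_plain(init, candidates)
print(sims)
```

Both inputs have three columns here, so the call succeeds and returns one similarity per candidate: 1 for the identical vector and 0 for the orthogonal one.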
With arguments whose array shapes do not match, however, the function throws an error.
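The failing case can be sketched in the same way. The helper fails_with_value_error and the sample arrays are hypothetical. The sketch shows two kinds of mismatch: a different number of columns, and inner arrays of unequal length, which NumPy can only hold as an object array:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def fails_with_value_error(x, y):
    # Report whether cosine_similarity rejects the pair with a ValueError.
    try:
        cosine_similarity(x, y)
    except ValueError:
        return True
    return False

init = np.array([[1.0, 0.0, 0.0]])  # shape (1, 3)

# Column counts differ: 3 features on one side, 2 on the other.
wrong_width = np.array([[1.0, 0.0], [0.0, 1.0]])

# Rows of unequal length cannot form a 2-D float array,
# so they are stored here as a 1-D array of Python lists.
ragged = np.empty(2, dtype=object)
ragged[0] = [1.0, 0.0, 0.0]
ragged[1] = [0.0, 1.0]

print(fails_with_value_error(init, wrong_width))
print(fails_with_value_error(init, ragged))
```

The object-array case is the one that typically produces the message "setting an array element with a sequence", because scikit-learn cannot convert an array whose elements are themselves sequences into a float matrix.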
Make sure the DataFrame columns you pass to the UDF hold arrays with matching shapes.
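One thing worth checking alongside the shapes: a scalar pandas UDF is called with a whole batch of rows, so each argument arrives as a pandas Series of arrays rather than as a single array. The sketch below handles one row at a time. It is not a drop-in fix, and it assumes every row has at least one embedding and no zero vectors. The names row_cosine_sims and cosine_sim_batch are hypothetical, and the similarity is computed with plain NumPy instead of scikit-learn.

```python
import numpy as np
import pandas as pd

def row_cosine_sims(init, embs):
    """Cosine similarity between one vector and each vector in a list."""
    if init is None or embs is None:
        return None  # embedding_init is nullable in the schema
    init = np.asarray(init, dtype=float)  # shape (n_features,)
    # Stack the inner arrays into a regular (n_embeddings, n_features) matrix.
    embs = np.vstack([np.asarray(e, dtype=float) for e in embs])
    norms = np.linalg.norm(embs, axis=1) * np.linalg.norm(init)
    return (embs @ init / norms).tolist()

def cosine_sim_batch(embedding_init: pd.Series, embeddings: pd.Series) -> pd.Series:
    # Each Series holds one entry per row of the batch, so loop over the rows.
    return pd.Series(
        [row_cosine_sims(i, e) for i, e in zip(embedding_init, embeddings)]
    )

# Registering it with Spark (assumes pyspark is installed and df has the schema above):
# from pyspark.sql import types as T
# from pyspark.sql.functions import pandas_udf
# cosine_sim = pandas_udf(cosine_sim_batch, returnType=T.ArrayType(T.DoubleType()))
# df = df.withColumn("cosine_similarity", cosine_sim("embedding_init", "embeddings"))
```

Keeping the per-row logic in a plain function makes it easy to try on a few hand-written vectors before running it on the cluster.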