How to use polars to calculate the cross product and hash of two columns

I have a polars DataFrame like df, and I want to calculate the cross product of 'A' and 'B' within each row, then hash each resulting pair. The expected result is df2 below. How can I do this efficiently with polars?

Step 1:

import polars as pl
import itertools

df = pl.DataFrame({
    'A': [[1,1],[2,2]],
    'B': [[3,4],[5,6]]}
)

print("df:", df)

out:

df: shape: (2, 2)
┌───────────┬───────────┐
│ A         ┆ B         │
│ ---       ┆ ---       │
│ list[i64] ┆ list[i64] │
╞═══════════╪═══════════╡
│ [1, 1]    ┆ [3, 4]    │
│ [2, 2]    ┆ [5, 6]    │
└───────────┴───────────┘

Step 2:

row1_prod = list(itertools.product([1, 1], [3, 4]))
row2_prod = list(itertools.product([2, 2], [5, 6]))
print("row1_prod: ", row1_prod)
print("row2_prod: ", row2_prod)

out:

row1_prod:  [(1, 3), (1, 4), (1, 3), (1, 4)]
row2_prod:  [(2, 5), (2, 6), (2, 5), (2, 6)]

Step 3:

C_row1 = [hash(e) % 100 for e in row1_prod]
C_row2 = [hash(e) % 100 for e in row2_prod]

print("C_row1:", C_row1)
print("C_row2:", C_row2)

out:

C_row1: [80, 14, 80, 14]
C_row2: [75, 72, 75, 72]

Step 4:

df2 = df.with_columns(
    pl.Series("C", [C_row1, C_row2])
)
print("df2:", df2)

out:

df2: shape: (2, 3)
┌───────────┬───────────┬────────────────┐
│ A         ┆ B         ┆ C              │
│ ---       ┆ ---       ┆ ---            │
│ list[i64] ┆ list[i64] ┆ list[i64]      │
╞═══════════╪═══════════╪════════════════╡
│ [1, 1]    ┆ [3, 4]    ┆ [80, 14, … 14] │
│ [2, 2]    ┆ [5, 6]    ┆ [75, 72, … 72] │
└───────────┴───────────┴────────────────┘

I tried using the apply method in polars, but it was more than 20x slower. How can I speed it up? Thanks in advance.

There is 1 best solution below

As per the updated "jagged" example from the comments:

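import itertools
import polars as pl

df = pl.DataFrame({
   "A": [[1, 2, 2], [3, 4]],
   "B": [[5], [7, 8]]
})

# Expected pairs for each row, computed in plain Python for reference
row1_prod = list(itertools.product([1, 2, 2], [5]))
row2_prod = list(itertools.product([3, 4], [7, 8]))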

print("row1_prod: ", row1_prod)
print("row2_prod: ", row2_prod)

row1_prod:  [(1, 5), (2, 5), (2, 5)]
row2_prod:  [(3, 7), (3, 8), (4, 7), (4, 8)]

That looks equivalent to running .explode() on each column individually:

(df.with_row_count()
   .explode("A")
   .explode("B")
)
shape: (7, 3)
┌────────┬─────┬─────┐
│ row_nr ┆ A   ┆ B   │
│ ---    ┆ --- ┆ --- │
│ u32    ┆ i64 ┆ i64 │
╞════════╪═════╪═════╡
│ 0      ┆ 1   ┆ 5   │
│ 0      ┆ 2   ┆ 5   │
│ 0      ┆ 2   ┆ 5   │
│ 1      ┆ 3   ┆ 7   │
│ 1      ┆ 3   ┆ 8   │
│ 1      ┆ 4   ┆ 7   │
│ 1      ┆ 4   ┆ 8   │
└────────┴─────┴─────┘

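We can then combine A and B into a single "item" with pl.struct() and .hash() the result. Note that polars' .hash() is not Python's built-in hash(), so the values will differ from those in the question.

Following that with .group_by().agg() gives back a single list per row.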

hashed = (
   df.with_row_count()
     .explode("A")
     .explode("B")
     .with_columns(hash = pl.struct("A", "B").hash() % 100)
     .group_by("row_nr", maintain_order=True)
     .agg("hash")
)
shape: (2, 2)
┌────────┬──────────────────┐
│ row_nr ┆ hash             │
│ ---    ┆ ---              │
│ u32    ┆ list[u64]        │
╞════════╪══════════════════╡
│ 0      ┆ [24, 73, 73]     │
│ 1      ┆ [53, 66, 69, 86] │
└────────┴──────────────────┘

Because we used maintain_order=True in the .group_by(), the row order is preserved, meaning we can assign the result directly via .with_columns():

df.with_columns(hashed.select("hash"))
shape: (2, 3)
┌───────────┬───────────┬──────────────────┐
│ A         ┆ B         ┆ hash             │
│ ---       ┆ ---       ┆ ---              │
│ list[i64] ┆ list[i64] ┆ list[u64]        │
╞═══════════╪═══════════╪══════════════════╡
│ [1, 2, 2] ┆ [5]       ┆ [24, 73, 73]     │
│ [3, 4]    ┆ [7, 8]    ┆ [53, 66, 69, 86] │
└───────────┴───────────┴──────────────────┘
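  • .groupby() was renamed to .group_by() in 0.19.0
  • .with_row_count() was renamed to .with_row_index() in 0.20.4 (the default column name changed from "row_nr" to "index")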