Aggregation in Polars window functions - how to select the top value based on an aggregation from other column

I have a large dataset on ocean freight that includes columns for bol, voyage_id, carrier, and total containers (teus), similar to this:

lf = pl.LazyFrame({
    'bol_id':(1,2,3,4,5,6,7,8,9),
    'voyage_id':(1,1,1,2,2,2,3,3,3),
    'carrier_scac':('mscu', 'mscu', 'hpld', 'hpld', 'hpld', 'hpld', 'ever', 'mscu', 'ever'),
    'teus':(20, 40, 5, 10, 25, 20, 5, 45, 5)
})
print(lf.collect())
┌────────┬───────────┬──────────────┬──────┐
│ bol_id ┆ voyage_id ┆ carrier_scac ┆ teus │
│ ---    ┆ ---       ┆ ---          ┆ ---  │
│ i64    ┆ i64       ┆ str          ┆ i64  │
╞════════╪═══════════╪══════════════╪══════╡
│ 1      ┆ 1         ┆ mscu         ┆ 20   │
│ 2      ┆ 1         ┆ mscu         ┆ 40   │
│ 3      ┆ 1         ┆ hpld         ┆ 5    │
│ 4      ┆ 2         ┆ hpld         ┆ 10   │
│ 5      ┆ 2         ┆ hpld         ┆ 25   │
│ 6      ┆ 2         ┆ hpld         ┆ 20   │
│ 7      ┆ 3         ┆ ever         ┆ 5    │
│ 8      ┆ 3         ┆ mscu         ┆ 45   │
│ 9      ┆ 3         ┆ ever         ┆ 5    │
└────────┴───────────┴──────────────┴──────┘

For each voyage, I want to get the carrier with the highest sum of teus. I can do this by a group_by followed by a join, but I'd like to do this with a window function and can't quite figure out the syntax/logic in Polars (0.20).

Current working function:

def add_primary_carrier(lf):
    lf2 = (
        lf
        #select relevant cols
        .select('voyage_id', 'carrier_scac', 'teus')
        #ignore bols with missing data
        .drop_nulls()
        #sum up TEUs by voyage and carrier
        .group_by('voyage_id', 'carrier_scac')
        .agg(pl.col('teus').sum().alias('sum_teus'))
        #choose the carrier with the most TEUs on each voyage
        .sort('sum_teus', descending=True)
        .group_by('voyage_id')
        .agg(pl.col('carrier_scac').first().alias('primary_scac'))
    )
    lf = (
        #add primary scac column to main lf
        lf.join(lf2, how='left', on='voyage_id')
    )
    return lf

But it seems a window function would be a lot cleaner (and perhaps less resource-intensive). Something like:

def add_primary_carrier_window(lf):
    lf = (
        lf.with_columns(
            pl.col('carrier_scac')
            .sort_by(pl.col('teus').sum().over('carrier_scac'), descending=True)
            .drop_nulls().first()
            .over('voyage_id')
            .alias('primary_scac')
        )
    )
    return lf

But that function throws an InvalidOperationError: "window expression not allowed in aggregation".

Thanks in advance for the help!

Expected output:

┌────────┬───────────┬──────────────┬──────┬──────────────┬──────────────┐
│ bol_id ┆ voyage_id ┆ carrier_scac ┆ teus ┆ primary_scac ┆ shared_cargo │
│ ---    ┆ ---       ┆ ---          ┆ ---  ┆ ---          ┆ ---          │
│ i64    ┆ i64       ┆ str          ┆ i64  ┆ str          ┆ bool         │
╞════════╪═══════════╪══════════════╪══════╪══════════════╪══════════════╡
│ 1      ┆ 1         ┆ mscu         ┆ 20   ┆ mscu         ┆ false        │
│ 2      ┆ 1         ┆ mscu         ┆ 40   ┆ mscu         ┆ false        │
│ 3      ┆ 1         ┆ hpld         ┆ 5    ┆ mscu         ┆ true         │
│ 4      ┆ 2         ┆ hpld         ┆ 10   ┆ hpld         ┆ false        │
│ 5      ┆ 2         ┆ hpld         ┆ 25   ┆ hpld         ┆ false        │
│ 6      ┆ 2         ┆ hpld         ┆ 20   ┆ hpld         ┆ false        │
│ 7      ┆ 3         ┆ ever         ┆ 5    ┆ mscu         ┆ true         │
│ 8      ┆ 3         ┆ mscu         ┆ 45   ┆ mscu         ┆ false        │
│ 9      ┆ 3         ┆ ever         ┆ 5    ┆ mscu         ┆ true         │
└────────┴───────────┴──────────────┴──────┴──────────────┴──────────────┘

There are 2 best solutions below

jqurious On BEST ANSWER

There are a few issues on the tracker about this, e.g. https://github.com/pola-rs/polars/issues/14361

Window expressions cannot be "nested", so each .over "aggregation" has to be created as its own column in a separate .with_columns call.

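(
    # step 1: total TEUs per (voyage, carrier), broadcast to every row
    lf.with_columns(
        pl.col('teus')
        .sum()
        .over('voyage_id', 'carrier_scac')
        .alias('sum_teus')
    )
    # step 2: within each voyage, take the carrier with the largest total
    .with_columns(
        pl.col('carrier_scac')
        .sort_by('sum_teus', descending=True)
        .first()
        .over('voyage_id')
        .alias('primary_scac')
    )
    .collect()
)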
shape: (9, 6)
┌────────┬───────────┬──────────────┬──────┬──────────┬──────────────┐
│ bol_id ┆ voyage_id ┆ carrier_scac ┆ teus ┆ sum_teus ┆ primary_scac │
│ ---    ┆ ---       ┆ ---          ┆ ---  ┆ ---      ┆ ---          │
│ i64    ┆ i64       ┆ str          ┆ i64  ┆ i64      ┆ str          │
╞════════╪═══════════╪══════════════╪══════╪══════════╪══════════════╡
│ 1      ┆ 1         ┆ mscu         ┆ 20   ┆ 60       ┆ mscu         │
│ 2      ┆ 1         ┆ mscu         ┆ 40   ┆ 60       ┆ mscu         │
│ 3      ┆ 1         ┆ hpld         ┆ 5    ┆ 5        ┆ mscu         │
│ 4      ┆ 2         ┆ hpld         ┆ 10   ┆ 55       ┆ hpld         │
│ 5      ┆ 2         ┆ hpld         ┆ 25   ┆ 55       ┆ hpld         │
│ 6      ┆ 2         ┆ hpld         ┆ 20   ┆ 55       ┆ hpld         │
│ 7      ┆ 3         ┆ ever         ┆ 5    ┆ 10       ┆ mscu         │
│ 8      ┆ 3         ┆ mscu         ┆ 45   ┆ 45       ┆ mscu         │
│ 9      ┆ 3         ┆ ever         ┆ 5    ┆ 10       ┆ mscu         │
└────────┴───────────┴──────────────┴──────┴──────────┴──────────────┘
BallpointBen On

This is also possible without sorting, using only over and group_by: compute the per-carrier sum with over, then pick the carrier at the position of the largest sum with arg_max.

(
    lf.with_columns(teus_sum=pl.col("teus").sum().over("voyage_id", "carrier_scac"))
    .group_by("voyage_id", maintain_order=True)
    .agg(
        pl.col("carrier_scac").get(pl.col("teus_sum").arg_max()),
        pl.col("teus_sum").max(),
    )
    .collect()
)
shape: (3, 3)
┌───────────┬──────────────┬──────────┐
│ voyage_id ┆ carrier_scac ┆ teus_sum │
│ ---       ┆ ---          ┆ ---      │
│ i64       ┆ str          ┆ i64      │
╞═══════════╪══════════════╪══════════╡
│ 1         ┆ mscu         ┆ 60       │
│ 2         ┆ hpld         ┆ 55       │
│ 3         ┆ mscu         ┆ 45       │
└───────────┴──────────────┴──────────┘

Or, if you want the data in the same shape as the original, you can use over instead of group_by (note that, without an alias, this overwrites carrier_scac with the primary carrier):

(
    lf.with_columns(teus_sum=pl.col("teus").sum().over("voyage_id", "carrier_scac"))
    .with_columns(
        pl.col("carrier_scac").get(pl.col("teus_sum").arg_max()).over("voyage_id"),
        pl.col("teus_sum").max().over("voyage_id"),
    )
    .drop("teus")
    .collect()
)
shape: (9, 4)
┌────────┬───────────┬──────────────┬──────────┐
│ bol_id ┆ voyage_id ┆ carrier_scac ┆ teus_sum │
│ ---    ┆ ---       ┆ ---          ┆ ---      │
│ i64    ┆ i64       ┆ str          ┆ i64      │
╞════════╪═══════════╪══════════════╪══════════╡
│ 1      ┆ 1         ┆ mscu         ┆ 60       │
│ 2      ┆ 1         ┆ mscu         ┆ 60       │
│ 3      ┆ 1         ┆ mscu         ┆ 60       │
│ 4      ┆ 2         ┆ hpld         ┆ 55       │
│ 5      ┆ 2         ┆ hpld         ┆ 55       │
│ 6      ┆ 2         ┆ hpld         ┆ 55       │
│ 7      ┆ 3         ┆ mscu         ┆ 45       │
│ 8      ┆ 3         ┆ mscu         ┆ 45       │
│ 9      ┆ 3         ┆ mscu         ┆ 45       │
└────────┴───────────┴──────────────┴──────────┘