Exponential Moving Average (EMA) calculations in Polars dataframe

I have the following list of 20 values:

values = [143.15,143.1,143.06,143.01,143.03,143.09,143.14,143.18,143.2,143.2,143.2,143.31,143.38,143.35,143.34,143.25,143.33,143.3,143.33,143.36]

To compute the Exponential Moving Average over a span of 9 values, I can do the following in Python:

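def calculate_ema(values, periods, smoothing=2):
    # Smoothing factor: smoothing / (periods + 1), i.e. 0.2 for periods=9
    alpha = smoothing / (1 + periods)
    # Seed the EMA with the simple average of the first `periods` values
    ema = [sum(values[:periods]) / periods]

    for price in values[periods:]:
        ema.append(price * alpha + ema[-1] * (1 - alpha))
    return ema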

ema_9 = calculate_ema(values, periods=9)
[143.10666666666668,
 143.12533333333334,
 143.14026666666666,
 143.17421333333334,
 143.21537066666667,
 143.24229653333333,
 143.26183722666667,
 143.25946978133334,
 143.27357582506667,
 143.27886066005334,
 143.28908852804267,
 143.30327082243414]

The resulting list of EMA values is 12 items long, the first value [0] corresponding to the 9th [8] value from values.

Using Pandas and TA-Lib, I can perform the following:

import pandas as pd
import talib as ta

df_pan = pd.DataFrame(
    {
        'value': values
    }
)

df_pan['ema_9'] = ta.EMA(df_pan['value'], timeperiod=9)

df_pan
    value   ema_9
0   143.15  NaN
1   143.10  NaN
2   143.06  NaN
3   143.01  NaN
4   143.03  NaN
5   143.09  NaN
6   143.14  NaN
7   143.18  NaN
8   143.20  143.106667
9   143.20  143.125333
10  143.20  143.140267
11  143.31  143.174213
12  143.38  143.215371
13  143.35  143.242297
14  143.34  143.261837
15  143.25  143.259470
16  143.33  143.273576
17  143.30  143.278861
18  143.33  143.289089
19  143.36  143.303271

The Pandas / TA-Lib output corresponds with that of my Python function.

However, when I try to replicate this using functionality purely in Polars:

import polars as pl

df = (
    pl.DataFrame(
        {
            'value': values
        }
    )
    .with_columns(
        pl.col('value').ewm_mean(span=9, min_periods=9).alias('ema_9')
    )
)

df

I get different values:

value   ema_9
f64 f64
143.15  null
143.1   null
143.06  null
143.01  null
143.03  null
143.09  null
143.14  null
143.18  null
143.2   143.128695
143.2   143.144672
143.2   143.156777
143.31  143.189683
143.38  143.229961
143.35  143.255073
143.34  143.272678
143.25  143.268011
143.33  143.280694
143.3   143.284626
143.33  143.293834
143.36  143.307221

Can anyone please explain what adjustments I need to make to my Polars code in order to get the expected results?

Wayoshi On BEST ANSWER

Two things here:

  • Reading the ewm_mean docs closely, you want adjust=False (default is True).
  • min_periods does not change the calculation: the EMA is still computed from the very first row, and the results before the min_periods-th row are simply replaced with null. Try removing min_periods and you will see that the tail values don't change at all.
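To make the two bullets concrete, here is a rough plain-Python sketch of the two weightings as I understand them from the docs. The helper name ewm_mean_sketch is made up for illustration and this is not the Polars implementation; it only assumes the documented conversion alpha = 2 / (span + 1).

```python
def ewm_mean_sketch(values, span, adjust=True):
    """Illustrative re-implementation of the two EWM weightings (hypothetical helper)."""
    alpha = 2 / (span + 1)  # span-to-alpha conversion
    out = []
    if adjust:
        # adjust=True: weighted average of all rows so far, weights (1 - alpha)**i
        num = den = 0.0
        for x in values:
            num = num * (1 - alpha) + x
            den = den * (1 - alpha) + 1
            out.append(num / den)
    else:
        # adjust=False: y_t = (1 - alpha) * y_{t-1} + alpha * x_t, seeded with y_0 = x_0
        for x in values:
            out.append(x if not out else (1 - alpha) * out[-1] + alpha * x)
    return out


values = [143.15, 143.1, 143.06, 143.01, 143.03, 143.09, 143.14, 143.18, 143.2, 143.2,
          143.2, 143.31, 143.38, 143.35, 143.34, 143.25, 143.33, 143.3, 143.33, 143.36]

adjusted = ewm_mean_sketch(values, span=9, adjust=True)
unadjusted = ewm_mean_sketch(values, span=9, adjust=False)

# The adjusted series reproduces the numbers from the question's Polars output
print(round(adjusted[8], 6), round(adjusted[-1], 6))
# The unadjusted series is seeded with the first value, so it still differs from TA-Lib
print(round(unadjusted[8], 6), round(unadjusted[-1], 6))
```

Note that nothing in this sketch depends on min_periods, which is why dropping it leaves the tail unchanged. Also note that adjust=False alone is not enough to match TA-Lib here, because the recursion starts from the first value rather than from the mean of the first 9.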

To actually change the calculation so that it starts from the mean of the first min_periods values, we can use pl.when with cumcount (a handy way to get the row index of a value) to replace the first 9 values with their mean. With adjust=False, the ewm_mean of that constant prefix stays at the same constant value up to the 9th row, so the recursion is effectively seeded with the simple average. The calculations for the earlier rows are still done under the hood, and min_periods=9 nulls them out in the end.

All together:

df.with_columns(
    pl.when(pl.col('value').cumcount() < 9)
    .then(pl.col('value').head(9).mean())
    .otherwise(pl.col('value'))
    .ewm_mean(span=9, min_periods=9, adjust=False)
    .alias('ema_9')
)
shape: (20, 2)
┌────────┬────────────┐
│ value  ┆ ema_9      │
│ ---    ┆ ---        │
│ f64    ┆ f64        │
╞════════╪════════════╡
│ 143.15 ┆ null       │
│ 143.1  ┆ null       │
│ 143.06 ┆ null       │
│ 143.01 ┆ null       │
│ 143.03 ┆ null       │
│ 143.09 ┆ null       │
│ 143.14 ┆ null       │
│ 143.18 ┆ null       │
│ 143.2  ┆ 143.106667 │
│ 143.2  ┆ 143.125333 │
│ 143.2  ┆ 143.140267 │
│ 143.31 ┆ 143.174213 │
│ 143.38 ┆ 143.215371 │
│ 143.35 ┆ 143.242297 │
│ 143.34 ┆ 143.261837 │
│ 143.25 ┆ 143.25947  │
│ 143.33 ┆ 143.273576 │
│ 143.3  ┆ 143.278861 │
│ 143.33 ┆ 143.289089 │
│ 143.36 ┆ 143.303271 │
└────────┴────────────┘
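As a cross-check of why this works, the same idea can be sketched in plain Python. The function name seeded_ema is made up for illustration; the sketch only assumes the adjust=False recursion with alpha = 2 / (span + 1) and mirrors the three steps of the expression (overwrite the prefix, run the recursion, null the warm-up rows).

```python
def seeded_ema(values, span):
    """Plain-Python mirror of the when/then/otherwise + ewm_mean trick (hypothetical helper)."""
    alpha = 2 / (span + 1)
    seed = sum(values[:span]) / span

    # Mimic pl.when(...).then(mean).otherwise(value): overwrite the first `span` inputs
    patched = [seed] * span + list(values[span:])

    # adjust=False recursion; on the constant prefix the output stays at `seed`
    out = [patched[0]]
    for x in patched[1:]:
        out.append((1 - alpha) * out[-1] + alpha * x)

    # Mimic min_periods=span: hide the first span - 1 results
    return [None] * (span - 1) + out[span - 1:]


values = [143.15, 143.1, 143.06, 143.01, 143.03, 143.09, 143.14, 143.18, 143.2, 143.2,
          143.2, 143.31, 143.38, 143.35, 143.34, 143.25, 143.33, 143.3, 143.33, 143.36]

ema_9 = seeded_ema(values, span=9)
print(ema_9[8], ema_9[-1])
```

The first non-null entry should agree with the seed value from calculate_ema in the question (about 143.106667) and the last with about 143.303271, up to floating-point rounding.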