wasserstein_1d returns an error for a correct expression


I am trying to calculate the p-Wasserstein distance with the POT (Python Optimal Transport) module. The following code raises an error, even though the call looks correct to me according to the documentation (https://pythonot.github.io/all.html).

Any ideas why?

import ot
import numpy as np

tab1 = np.random.normal(2,1,1000)
tab2 = np.random.normal(0,1,1000)

ot.wasserstein_1d(tab1,tab2)

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-31-5e5baa2dbe40> in <module>
      5 tab2 = np.random.normal(0,1,1000)
      6 
----> 7 ot.wasserstein_1d(tab1,tab2)

C:\ProgramData\Miniconda3\envs\py37_v1\lib\site-packages\ot\lp\solver_1d.py in wasserstein_1d(u_values, v_values, u_weights, v_weights, p, require_sort)
    125     u_quantiles = quantile_function(qs, u_cumweights, u_values)
    126     v_quantiles = quantile_function(qs, v_cumweights, v_values)
--> 127     qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
    128     delta = qs[1:, ...] - qs[:-1, ...]
    129     diff_quantiles = nx.abs(u_quantiles - v_quantiles)

C:\ProgramData\Miniconda3\envs\py37_v1\lib\site-packages\ot\backend.py in zero_pad(self, a, pad_width)
   1026 
   1027     def zero_pad(self, a, pad_width):
-> 1028         return np.pad(a, pad_width)
   1029 
   1030     def argmax(self, a, axis=None):

TypeError: pad() missing 1 required positional argument: 'mode'

There are 2 best solutions below

Answer 1

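Your installed NumPy is older than the version the ot package expects. The traceback shows POT calling np.pad(a, pad_width) without a mode argument. As far as I know, mode was a required argument in older NumPy releases and only became optional (defaulting to 'constant') in NumPy 1.17. Upgrade NumPy to a version compatible with your POT release and the call will work.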
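A quick way to check this in the failing environment is sketched below. It assumes the 1.17 threshold mentioned in the NumPy release notes as I remember them, and the helper names `numpy_version_tuple` and `pad_mode_is_optional` are made up for illustration; they are not part of POT or NumPy. The last lines repeat the padding pattern from the traceback with `mode` passed explicitly.

```python
import re
import numpy as np


def numpy_version_tuple(version):
    """Return (major, minor) as ints from a version string such as '1.16.4'."""
    numbers = []
    for piece in version.split(".")[:2]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


def pad_mode_is_optional(version=np.__version__):
    """True if this NumPy version lets np.pad be called without `mode`.

    Assumption: `mode` became optional in NumPy 1.17.
    """
    return numpy_version_tuple(version) >= (1, 17)


print("NumPy", np.__version__, "- pad mode optional:", pad_mode_is_optional())

# Same padding pattern as the failing line in the traceback, but with `mode`
# passed explicitly, which works on old and new NumPy alike.
qs = np.array([0.25, 0.5, 1.0])
padded = np.pad(qs, [(1, 0)], mode="constant")
print(padded)
```

If the check prints False, upgrading NumPy inside that environment (for example with `pip install --upgrade numpy`, or the conda equivalent) is the likely fix.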

Answer 2

I don't know what is up with your environment.

On Windows 10 Pro x64 with Python 3.10.8 (x64), POT 0.8.2, and NumPy 1.23.5, everything works.

Your code example

import numpy as np
import ot

tab1 = np.random.normal(2, 1, 1000)
tab2 = np.random.normal(0, 1, 1000)

q = ot.wasserstein_1d(tab1, tab2)
print(q)

prints

2.0226588606978444

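or some other value close to 2, depending on the random sample.

That looks right to me: the two normal distributions have the same variance and their means differ by 2, so the distance should be about 2.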
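For anyone who wants to sanity-check that value without POT, here is a minimal NumPy-only sketch. It is not POT's implementation. It relies on the fact that, for two samples of equal size with uniform weights, the 1D optimal transport cost pairs the sorted values, so the p = 1 distance is the mean absolute difference of the sorted samples. The function name `wasserstein_1d_equal_size` is hypothetical.

```python
import numpy as np


def wasserstein_1d_equal_size(u_values, v_values, p=1):
    """Mean of |u_(i) - v_(i)|**p over sorted samples.

    Only valid for uniform weights and samples of equal size.
    No 1/p root is applied to the result.
    """
    u_sorted = np.sort(np.asarray(u_values, dtype=float))
    v_sorted = np.sort(np.asarray(v_values, dtype=float))
    if u_sorted.shape != v_sorted.shape:
        raise ValueError("this sketch only handles samples of equal size")
    return float(np.mean(np.abs(u_sorted - v_sorted) ** p))


rng = np.random.RandomState(0)  # fixed seed so the run is repeatable
tab1 = rng.normal(2, 1, 1000)
tab2 = rng.normal(0, 1, 1000)

w1 = wasserstein_1d_equal_size(tab1, tab2)
print(w1)  # a value close to 2 for these two distributions
```

With the default p = 1, the result should be comparable to what `ot.wasserstein_1d(tab1, tab2)` returns on the same arrays, which gives an independent check on a working installation.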