1D Wasserstein distance in Python


The formula below is a special case of the Wasserstein distance (optimal transport) that applies when the source and target distributions, x and y (also called marginal distributions), are 1D, that is, vectors of real-valued samples.

$$ W_p(u, v) = \left( \int_0^1 \left| F_u^{-1}(z) - F_v^{-1}(z) \right|^p \, dz \right)^{1/p} $$

where the F^{-1} are the inverse cumulative distribution functions (quantile functions) of the marginals u and v, estimated from real data called x and y, both generated here from the normal distribution:

import numpy as np
from numpy.random import randn
import scipy.stats as ss

n = 100
x = randn(n)
y = randn(n)

How can the integral in the formula be coded in Python with SciPy? I'm guessing that x and y have to be converted to ranked marginals, which are non-negative and sum to 1, and that SciPy's ppf could be used to calculate the inverses F^{-1}?

2

There are 2 best solutions below

8
On BEST ANSWER

Note that as n gets large, a sorted set of n samples approaches the inverse CDF sampled at 1/n, 2/n, ..., n/n. E.g.:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# evaluate at interval midpoints, since norm.ppf(0) = -inf and norm.ppf(1) = +inf
plt.plot(norm.ppf((np.arange(1000) + 0.5) / 1000), label="invcdf")
plt.plot(np.sort(np.random.normal(size=1000)), label="sortsample")
plt.legend()
plt.show()


Also note that your integral from 0 to 1 can be approximated as a sum over 1/n, 2/n, ..., n/n.
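Written out, the two observations above combine as follows. This is a sketch that assumes the formula in the question is the standard quantile form of the p-Wasserstein distance; u_{(i)} and v_{(i)} denote the i-th smallest samples, notation that is introduced here and not taken from the question.

```latex
W_p(u, v)
  = \left( \int_0^1 \left| F_u^{-1}(z) - F_v^{-1}(z) \right|^p \, dz \right)^{1/p}
  \approx \left( \frac{1}{n} \sum_{i=1}^{n}
      \left| F_u^{-1}\!\left(\tfrac{i}{n}\right) - F_v^{-1}\!\left(\tfrac{i}{n}\right) \right|^p \right)^{1/p}
  \approx \left( \frac{1}{n} \sum_{i=1}^{n} \left| u_{(i)} - v_{(i)} \right|^p \right)^{1/p}
```

The first step replaces the integral by a sum over the grid i/n, and the second replaces each quantile by the corresponding sorted sample.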

Thus we can simply answer your question:

def W(p, u, v):
    assert len(u) == len(v)
    return np.mean(np.abs(np.sort(u) - np.sort(v))**p)**(1/p)

Note that if len(u) != len(v) you can still apply the method with linear interpolation:

def W(p, u, v):
    u = np.sort(u)
    v = np.sort(v)
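    if len(u) != len(v):
        # make u the shorter sample, then resample it onto v's grid
        if len(u) > len(v): u, v = v, u
        us = np.linspace(0, 1, len(u))
        vs = np.linspace(0, 1, len(v))
        u = np.interp(vs, us, u)  # evaluate u's sorted values at v's grid points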
    return np.mean(np.abs(u - v)**p)**(1/p)
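As a sanity check, for p = 1 and equal sample sizes the sorted-sample formula should agree with `scipy.stats.wasserstein_distance`, which computes the 1-Wasserstein distance between empirical distributions. A minimal sketch; the helper name `w_sorted`, the seed and the shift of 0.5 are illustrative choices, not part of the answer:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def w_sorted(p, u, v):
    """Sorted-sample estimate of W_p for two samples of equal size."""
    u = np.sort(np.asarray(u, dtype=float))
    v = np.sort(np.asarray(v, dtype=float))
    assert len(u) == len(v)
    return np.mean(np.abs(u - v) ** p) ** (1 / p)

rng = np.random.default_rng(0)          # fixed seed, only for reproducibility
x = rng.normal(size=100)
y = rng.normal(loc=0.5, size=100)       # shifted sample so the distance is nonzero

ours = w_sorted(1, x, y)
ref = wasserstein_distance(x, y)        # SciPy's W_1 between the empirical distributions
print(ours, ref)
```

The two numbers should match up to floating-point error. For p other than 1 this SciPy function is not a reference, since it only computes W_1.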

An alternative method, if you have prior information about the family of distributions your data comes from but not its parameters, is to fit that distribution to your data (e.g. with scipy.stats.norm.fit) for both u and v and then evaluate the integral numerically to the desired precision. E.g.:

from scipy.stats import norm as gauss
def W_gauss(p, u, v, num_steps):
    ud = gauss(*gauss.fit(u))
    vd = gauss(*gauss.fit(v))
    z = np.linspace(0, 1, num_steps, endpoint=False) + 1/(2*num_steps)
    return np.mean(np.abs(ud.ppf(z) - vd.ppf(z))**p)**(1/p)
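
One way to check the midpoint-rule integration is against the known closed form for two normal distributions at p = 2, W_2 = sqrt((m1 - m2)^2 + (s1 - s2)^2). The sketch below skips the fitting step and uses fixed parameters; the helper names and the parameter values are made up for illustration:

```python
import numpy as np
from scipy.stats import norm

def w_quantile_midpoint(p, dist_u, dist_v, num_steps):
    """Midpoint-rule approximation of the quantile integral for two frozen distributions."""
    z = (np.arange(num_steps) + 0.5) / num_steps   # midpoints of num_steps equal bins on (0, 1)
    return np.mean(np.abs(dist_u.ppf(z) - dist_v.ppf(z)) ** p) ** (1 / p)

def w2_gauss_closed_form(m1, s1, m2, s2):
    """Closed-form W_2 between N(m1, s1^2) and N(m2, s2^2)."""
    return np.hypot(m1 - m2, s1 - s2)

numeric = w_quantile_midpoint(2, norm(0.0, 1.0), norm(1.0, 2.0), 100_000)
exact = w2_gauss_closed_form(0.0, 1.0, 1.0, 2.0)
print(numeric, exact)
```

The numerical value approaches the closed form as num_steps grows; the remaining error comes mostly from the tails near z = 0 and z = 1.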
0
On

I guess I am a bit late, but this is what I would do for an exact solution (using only numpy):

import numpy as np
from numpy.random import randn
n = 100
m = 80
p = 2
x = np.sort(randn(n))
y = np.sort(randn(m))
a = np.ones(n)/n
b = np.ones(m)/m
# cdfs
ca = np.cumsum(a)
cb = np.cumsum(b)

# points on which we need to evaluate the quantile functions
cba = np.sort(np.hstack([ca, cb]))
# weights for integral
h = np.diff(np.hstack([0, cba]))

# construction of first quantile function
bins = ca + 1e-10 # small tolerance to avoid rounding errors at the breakpoints
index_qx = np.digitize(cba, bins, right=True)    # right=True because the quantile function is
                                                 # left-continuous: F^{-1}(t) = x[i] for ca[i-1] < t <= ca[i]
qx = x[index_qx] # quantile function F^{-1}

# construction of second quantile function
bins = cb + 1e-10
index_qy = np.digitize(cba, bins, right=True)    # same convention as for the first quantile function
qy = y[index_qy] # quantile function G^{-1}

# this is W_p^p; take ot_cost**(1/p) to get W_p (np.abs keeps odd p correct)
ot_cost = np.sum(np.abs(qx - qy)**p * h)
print(ot_cost)
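
For reuse, the same idea can be wrapped in a function and cross-checked against `scipy.stats.wasserstein_distance` for p = 1 with unequal sample sizes. This is a sketch rather than the code above: the name `exact_wp_pow` is made up, it assumes uniform weights, and it uses `np.searchsorted` with a small tolerance in place of `np.digitize`.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def exact_wp_pow(x, y, p=1):
    """W_p^p between the empirical distributions of x and y (uniform weights)."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    n, m = len(x), len(y)
    ca = np.arange(1, n + 1) / n                 # CDF levels of x: 1/n, ..., 1
    cb = np.arange(1, m + 1) / m                 # CDF levels of y: 1/m, ..., 1
    cba = np.sort(np.concatenate([ca, cb]))      # all breakpoints of both quantile functions
    h = np.diff(np.concatenate([[0.0], cba]))    # lengths of the pieces between breakpoints
    # Index of the quantile at level t: smallest i with ca[i] >= t.
    # The 1e-12 shift guards against rounding at coinciding breakpoints.
    ix = np.minimum(np.searchsorted(ca, cba - 1e-12, side="left"), n - 1)
    iy = np.minimum(np.searchsorted(cb, cba - 1e-12, side="left"), m - 1)
    return np.sum(np.abs(x[ix] - y[iy]) ** p * h)

rng = np.random.default_rng(1)                   # fixed seed, only for reproducibility
x = rng.normal(size=100)
y = rng.normal(size=80)

w1_exact = exact_wp_pow(x, y, p=1)
w1_scipy = wasserstein_distance(x, y)
print(w1_exact, w1_scipy)
```

For p = 1 the two values should agree up to floating-point error; for other p the function returns W_p^p, so take the p-th root to obtain W_p.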
        

In case you are interested, here you can find a more detailed NumPy-based implementation of the optimal transport (OT) problem on the real line, with dual and primal solutions as well: https://github.com/gnies/1d-optimal-transport. (I am still working on it, though.)