scipy.stats.wasserstein_distance implementation


I am trying to understand the implementation used in scipy.stats.wasserstein_distance.

For p=1 and no weights, with u_values and v_values the two 1-D samples, the code comes down to:

u_sorter = np.argsort(u_values)  # (1)
v_sorter = np.argsort(v_values)

all_values = np.concatenate((u_values, v_values))  # (2)
all_values.sort(kind='mergesort')

deltas = np.diff(all_values)  # (3)

u_cdf_indices = u_values[u_sorter].searchsorted(all_values[:-1], 'right')  # (4)
v_cdf_indices = v_values[v_sorter].searchsorted(all_values[:-1], 'right')

v_cdf = v_cdf_indices / v_values.size  # (5)
u_cdf = u_cdf_indices / u_values.size

return np.sum(np.multiply(np.abs(u_cdf - v_cdf), deltas))  # (6)
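
For reference, here is that reduction wrapped as a standalone function (a minimal sketch; the name w1_1d is mine, not scipy's), which can be checked against scipy's own result:

import numpy as np
from scipy.stats import wasserstein_distance

def w1_1d(u_values, v_values):
    # Pool and sort all sample values; the empirical CDFs can only
    # jump at these points.
    all_values = np.concatenate((u_values, v_values))
    all_values.sort(kind='mergesort')
    # Widths of the intervals between consecutive pooled values.
    deltas = np.diff(all_values)
    # For each interval's left endpoint, count the samples <= it and
    # normalize: this is the empirical CDF evaluated on that interval.
    u_cdf = np.sort(u_values).searchsorted(all_values[:-1], 'right') / u_values.size
    v_cdf = np.sort(v_values).searchsorted(all_values[:-1], 'right') / v_values.size
    # The integral of |U - V| becomes a finite sum over the intervals.
    return np.sum(np.abs(u_cdf - v_cdf) * deltas)

u = np.array([0.0, 1.0, 3.0])
v = np.array([5.0, 6.0, 8.0])
print(w1_1d(u, v))                 # 5.0
print(wasserstein_distance(u, v))  # 5.0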

What is the reasoning behind this implementation? Is there some literature? I did look at the paper cited, which I believe explains why calculating the Wasserstein distance in its general 1-D definition is equivalent to evaluating the integral,


\int_{-\infty}^{+\infty} |U(x) - V(x)| \, dx,

with U and V the cumulative distribution functions of u_values and v_values, but I don't understand how this integral is evaluated in the scipy implementation.
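
My best guess is that, because U and V are empirical CDFs, they are step functions that are constant between consecutive pooled sample values x_1 \le \dots \le x_N (with N the total number of samples), so the integral should reduce to a finite sum,

\int_{-\infty}^{+\infty} |U(x) - V(x)| \, dx = \sum_{i=1}^{N-1} |U(x_i) - V(x_i)| \, (x_{i+1} - x_i),

but I would like to confirm that this is what the code implements.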

In particular,
a) why are they multiplying by the deltas in (6) to evaluate the integral?
b) how are u_cdf and v_cdf in (5) the cumulative distribution functions U and V?

Also, this implementation does not preserve the element order of u_values and v_values. Shouldn't the order matter in the general definition of the Wasserstein distance?

Thank you for your help!

1 Answer

The order of the PDF, histogram, or KDE is preserved, and that order is what matters in the Wasserstein distance. If you only pass u_values and v_values, the function has to calculate something like a PDF, KDE, or histogram from the samples itself. Normally you would provide the values of U and V together with their weights (effectively a histogram/PDF) as the four arguments u_values, v_values, u_weights, v_weights of wasserstein_distance. So in the case where only samples are provided, you are not passing a real datapoint, simply a collection of repeated "experiments".

Steps (1) and (4) in your list of code blocks essentially bin your data at its distinct values. A CDF evaluated at a point is the number of values up to that point, or P(X <= x); it is basically the cumulative sum of a PDF, histogram, or KDE. Step (5) normalizes the CDF to between 0.0 and 1.0, or said another way, it divides each cumulative count by the total number of samples.
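
A minimal sketch of those binning and normalization steps in isolation (the sample values here are made up for illustration):

import numpy as np

u_values = np.array([0.0, 1.0, 3.0])
v_values = np.array([5.0, 6.0, 8.0])

# Pooled, sorted values: every point where either empirical CDF can jump.
all_values = np.sort(np.concatenate((u_values, v_values)))

# Count how many u samples are <= each left endpoint (the raw CDF counts)...
u_cdf_indices = np.sort(u_values).searchsorted(all_values[:-1], 'right')

# ...then divide by the number of samples so the CDF runs from 0 to 1.
u_cdf = u_cdf_indices / u_values.size
print(u_cdf_indices)  # [1 2 3 3 3]
print(u_cdf)          # [0.333... 0.666... 1. 1. 1.]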

So it is the order of the distinct values that is preserved, not the original order of the datapoints.
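
A quick check of that point (a sketch; assumes NumPy and SciPy are available):

import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
u = rng.normal(size=100)
v = rng.normal(loc=1.0, size=100)

# Shuffling the samples leaves the distance unchanged: only the empirical
# distributions matter, not the original order of the datapoints.
print(wasserstein_distance(u, v))
print(wasserstein_distance(rng.permutation(u), v))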

B) It may make more sense if you plot the CDFs of a datapoint (such as an image file) using the code above.
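
For instance, a rough way to plot two empirical CDFs as step functions (a sketch using matplotlib, with made-up samples in place of real image data):

import numpy as np
import matplotlib.pyplot as plt

u_values = np.array([0.0, 1.0, 3.0])
v_values = np.array([5.0, 6.0, 8.0])

# An empirical CDF jumps by 1/n at each sorted sample value.
for values, label in ((u_values, 'U'), (v_values, 'V')):
    xs = np.sort(values)
    ys = np.arange(1, xs.size + 1) / xs.size
    plt.step(xs, ys, where='post', label=label)

# The 1-D Wasserstein distance is the area between the two curves.
plt.xlabel('value')
plt.ylabel('cumulative probability')
plt.legend()
plt.show()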

The transportation problem, however, may not need a PDF but rather a datapoint of ordered features, or some way to measure the distance between features, in which case you would calculate it differently.