Problem with inputs for estimating Earth Mover's Distance with emd from the Python ot package


When using the emd and emd2 functions from the ot package (POT), I'm a little confused about how to pass a and b so that they are consistent with M, the cost matrix.

Without the a and b inputs, things work as I expect:

import ot
import numpy as np

np.random.seed(943)
d1 = np.random.normal(0,1,size=(50,2)) #samples from multivariate distribution 1: N x k, N=50, k=2
d2 = np.random.normal(2,1,size=(50,2)) #samples from multivariate distribution 2: N x k, N=50, k=2
C = ot.dist(d1,d2,metric='euclidean') #cost matrix with elements ||d1[i] - d2[j]||
print(C.shape) #NxN
otp = ot.emd([],[],C) #optimal transport plan, also NxN
print('emd: ',np.sum(otp*C)) #emd estimate
emd = ot.emd2([],[],C)
print('emd: ',emd) #or obtain it directly with emd2
#emd values from the two preceding print statements
#emd:  2.5941590559952763
#emd:  2.5941590559952763

This works because a and b are left empty, which according to the documentation means "uniform weights if empty list". My first question: since this assumes a uniform distribution for each of the marginals, isn't this an incorrect EMD, given that the marginals are clearly not uniform? My (potentially incorrect) assumption is therefore that leaving a and b empty does not give the correct distance.

I also saw another question here whose solution, which uses only the cost matrix as input, produces the same value as the two EMD estimates above:

from scipy.optimize import linear_sum_assignment
assignment = linear_sum_assignment(C)
print(C[assignment].sum() / 50) #emd value of 2.594159055995276
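The agreement is expected when both samples have the same size N and equal weights 1/N: the EMD linear program then has an optimal plan that is a permutation matrix scaled by 1/N (Birkhoff–von Neumann), so it reduces to an assignment problem. A small sketch that checks this numerically on random data (not the question's data), using scipy.optimize.linprog as a stand-in for ot.emd2 so the transport constraints are written out explicitly:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment, linprog
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
N = 8  # small N keeps the LP tiny
x = rng.normal(0, 1, size=(N, 2))
y = rng.normal(2, 1, size=(N, 2))
C = cdist(x, y)  # Euclidean cost matrix, C[i, j] = ||x[i] - y[j]||

a = np.full(N, 1.0 / N)  # uniform weights on the x samples
b = np.full(N, 1.0 / N)  # uniform weights on the y samples

# Transport LP over the flattened plan P (row-major, P[i, j] -> index i*N + j):
#   minimize sum_ij C[i, j] * P[i, j]  s.t.  row sums = a, column sums = b, P >= 0
A_rows = np.kron(np.eye(N), np.ones((1, N)))  # constraint i sums P[i, :]
A_cols = np.kron(np.ones((1, N)), np.eye(N))  # constraint j sums P[:, j]
res = linprog(C.ravel(), A_eq=np.vstack([A_rows, A_cols]),
              b_eq=np.concatenate([a, b]), bounds=(0, None), method="highs")
emd_lp = res.fun

# Assignment: one-to-one matching, each matched pair carrying mass 1/N
rows, cols = linear_sum_assignment(C)
emd_assign = C[rows, cols].sum() / N

print(emd_lp, emd_assign)
```

With unequal sample sizes or non-uniform weights the optimal plan can split mass between points, and the assignment shortcut no longer applies.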

My real confusion starts when I try to pass a and b. If I pass the raw observations (the arrays d1 and d2), I get the following, fairly expected, error:

AssertionError: 
Arrays are not almost equal to 6 decimals
a and b vector must have the same sum
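Both error messages point at the same requirement: a and b are weight vectors, not coordinates. The helper below is hypothetical (not a POT function) and only spells out the consistency conditions implied by the two messages, namely a of length M.shape[0], b of length M.shape[1], non-negative entries, and equal total mass:

```python
import numpy as np

def check_emd_inputs(a, b, M):
    """Hypothetical helper (not part of POT): list the problems that make
    (a, b, M) inconsistent as inputs to an EMD solver."""
    a, b, M = np.asarray(a, float), np.asarray(b, float), np.asarray(M, float)
    if M.ndim != 2:
        return ["M must be a 2-D cost matrix"]
    problems = []
    if a.shape != (M.shape[0],):
        problems.append(f"a has shape {a.shape}, expected ({M.shape[0]},)")
    if b.shape != (M.shape[1],):
        problems.append(f"b has shape {b.shape}, expected ({M.shape[1]},)")
    if (a < 0).any() or (b < 0).any():
        problems.append("weights must be non-negative")
    if not np.isclose(a.sum(), b.sum()):
        problems.append(f"a sums to {a.sum():.6f} but b sums to {b.sum():.6f}")
    return problems

rng = np.random.default_rng(943)
N = 50
d1 = rng.normal(0, 1, size=(N, 2))
d2 = rng.normal(2, 1, size=(N, 2))
M = np.linalg.norm(d1[:, None, :] - d2[None, :, :], axis=-1)  # N x N Euclidean costs

print(check_emd_inputs(d1, d2, M))  # raw coordinates: several problems reported
w = np.full(N, 1.0 / N)
print(check_emd_inputs(w, w, M))    # uniform sample weights -> []
```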

I believe it expects distributions, since the docs describe these inputs as "Source histogram". But if I convert the raw data to histograms, the shapes are no longer compatible with the cost matrix: the cost matrix is N x N while the histograms are n_bins x n_bins, which leads to another error:

hist1, _ = np.histogramdd(d1,bins=[2,2],density=True) #shape n_bins x n_bins = 2x2
hist2, _ = np.histogramdd(d2,bins=[2,2],density=True) #shape n_bins x n_bins = 2x2; while C is NxN
ot.emd2(hist1, hist2, C)

AssertionError: Dimension mismatch, check dimensions of M with a and b

Can anyone help me understand where I'm going wrong in my implementation, and how to implement this correctly? Or, if the first implementations without a and b are correct, can you clarify what role a and b play and why the EMD is correct without them?
