Using the weighted Minkowski metric in sklearn's BallTree


I have been playing around with BallTree and the different metrics that it provides. However, when I use wminkowski, it seems the weights have no impact whatsoever on the outcome:

import numpy as np
import pandas as pd
import sklearn.neighbors

df = pd.DataFrame()
num_features = 4
num_samples = 100
for i in range(num_features):
    df['A_%s'%(i+1)] = np.random.rand(num_samples)
    df['A_%s'%(i+1)] = df['A_%s'%(i+1)].apply(lambda x: 500 - (1000 * x ** 3))
point = np.array([int(1000 * r ** 3) for r in np.random.rand(num_features)]).reshape(1, -1)
weights = [int(10000 * r ** 2) for r in np.random.rand(num_features)]

tree1 = sklearn.neighbors.BallTree(df, metric='minkowski')
tree2 = sklearn.neighbors.BallTree(df, metric='wminkowski', p=2, w=[1] * num_features) # Should be just like tree1
tree3 = sklearn.neighbors.BallTree(df, metric='wminkowski', p=2, w=weights)

d1, i1 = tree1.query(point, k=5)
d2, i2 = tree2.query(point, k=5)
d3, i3 = tree2.query(point, k=5)

print('Point:')
print(point)
print('Weights:')
print(weights)
print('Distances:')
print(d1)
print(d2)
print(d3)
print('Indices:')
print(i1)
print(i2)
print(i3)

And the output is:

Point:
[[ 16  58   0 884]]
Weights:
[2869, 46, 1558, 5835]
Distances:
[[ 451.55203926  537.61234492  601.29840519  601.74059138  647.46934474]]
[[ 451.55203926  537.61234492  601.29840519  601.74059138  647.46934474]]
[[ 451.55203926  537.61234492  601.29840519  601.74059138  647.46934474]]
Indices:
[[61 31 86 43 93]]
[[61 31 86 43 93]]
[[61 31 86 43 93]]

I have run the above code with different numbers of features and samples, and every time all three trees return exactly the same output, whereas I expect the output returned by tree3 to be different. Why is that? I am using sklearn version 0.18.1.

Best answer:

My guess is that this is a typo in the example: d3, i3 are assigned from a query on tree2 rather than tree3. The offending line:

d3, i3 = tree2.query(point, k=5)

You probably meant:

d3, i3 = tree3.query(point, k=5)

After changing tree2 to tree3, the query on tree3 returns different results.
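
A typo like this can be caught by checking the tree output against a brute-force search. Below is a minimal pure-Python sketch. The helper names weighted_minkowski and brute_force_knn and the tiny data set are hypothetical, made up for illustration. It assumes the definition (sum_i w_i * |x_i - y_i|**p) ** (1/p); sklearn's exact weighting convention for wminkowski may differ, so treat the formula as an assumption. It only illustrates that non-uniform weights should change the neighbour ranking.

```python
def weighted_minkowski(x, y, w, p=2):
    """Assumed definition: (sum_i w_i * |x_i - y_i|**p) ** (1/p)."""
    total = sum(wi * abs(xi - yi) ** p for xi, yi, wi in zip(x, y, w))
    return total ** (1.0 / p)


def brute_force_knn(data, point, w, k, p=2):
    """Return (distances, indices) of the k nearest rows by exhaustive search."""
    ranked = sorted(
        (weighted_minkowski(row, point, w, p), i) for i, row in enumerate(data)
    )
    return [d for d, _ in ranked[:k]], [i for _, i in ranked[:k]]


# Hand-made data (hypothetical values, chosen so that the weights matter).
data = [
    [1.0, 10.0],   # near the query along feature 0, far along feature 1
    [10.0, 0.0],   # far along feature 0, near along feature 1
    [6.0, 6.0],
]
point = [0.0, 0.0]

unit_weights = [1.0, 1.0]      # should behave like the plain Minkowski metric
skewed_weights = [100.0, 1.0]  # feature 0 dominates the distance

d_unit, i_unit = brute_force_knn(data, point, unit_weights, k=3)
d_skew, i_skew = brute_force_knn(data, point, skewed_weights, k=3)

print(i_unit)  # neighbour order with unit weights
print(i_skew)  # neighbour order with skewed weights
```

With unit weights the neighbour order is [2, 1, 0]; with the skewed weights it becomes [0, 2, 1]. If a weighted tree returns exactly the same distances and indices as the unweighted one, as in the question, that points to a mistake in the calling code rather than in the metric.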