Generic Computation of Distance Matrices in PyTorch

I have two tensors a and b of shape (m, n), and I would like to compute a distance matrix m using some distance metric d. That is, I want m[i][j] = d(a[i], b[j]). This is similar to torch.cdist(a, b), but with a generic distance function d that is not necessarily a p-norm distance. Is there a generic way to implement this in PyTorch?

A more specific side question: is there an efficient way to do this with the following metric?

d(x,y) = 1 - cos(x,y)

edit

I've solved the specific case above using this answer:

import torch

def metric(a, b, eps=1e-8):
    # Row-wise L2 norms as column vectors, shape (rows, 1)
    a_norm, b_norm = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    # Normalize each row; clamping the norm at eps avoids division by zero
    a_unit = a / a_norm.clamp(min=eps)
    b_unit = b / b_norm.clamp(min=eps)
    # Entry [i, j] is cos(a[i], b[j])
    similarity_matrix = torch.mm(a_unit, b_unit.transpose(0, 1))
    return 1 - similarity_matrix
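
A quick sanity check of this approach, as a minimal sketch: it restates metric in compact form so the block runs on its own, and compares it against torch.nn.functional.cosine_similarity applied to broadcast inputs. The random test tensors and their sizes (4 and 5 rows, which need not match) are assumptions made for illustration only.

```python
import torch
import torch.nn.functional as F

def metric(a, b, eps=1e-8):
    # Normalize each row to unit length, then take one matrix product.
    a_unit = a / a.norm(dim=1, keepdim=True).clamp(min=eps)
    b_unit = b / b.norm(dim=1, keepdim=True).clamp(min=eps)
    return 1 - a_unit @ b_unit.T

torch.manual_seed(0)
a, b = torch.randn(4, 3), torch.randn(5, 3)  # illustrative sizes

# Reference: broadcast to shape (4, 5, 3) and reduce over the last dimension.
reference = 1 - F.cosine_similarity(a[:, None, :], b[None, :, :], dim=2)

dist = metric(a, b)
print(dist.shape)  # torch.Size([4, 5])
print(torch.allclose(dist, reference, atol=1e-6))
```

The matrix-product version never materializes the (4, 5, 3) intermediate that the broadcast reference needs, which is why it should scale better for large inputs.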

Best answer

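I'd suggest using broadcasting: since a and b both have shape (m, n), you can compute

m = d(a[:, None, :], b[None, :, :])

Here a[:, None, :] has shape (m, 1, n) and b[None, :, :] has shape (1, m, n), so the two broadcast to (m, m, n) and m[i][j] = d(a[i], b[j]). (Swapping the two index patterns would give the transposed matrix.) The function d needs to reduce over the last dimension, so for instance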

def d(a, b):
    return 1 - (a * b).sum(dim=2) / (a.norm(dim=2) * b.norm(dim=2))

(here I assume that cos(x,y) represents the normalized inner product between x and y)
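
To make the broadcasting idea concrete, here is a minimal sketch of a generic helper. The names pairwise, cosine_distance, and manhattan are hypothetical, chosen for illustration; the only requirement on d is that it accepts two broadcastable tensors and reduces over the last dimension. The L1 metric is included only because it can be cross-checked against torch.cdist.

```python
import torch

def pairwise(a, b, d):
    """Return m with m[i, j] = d(a[i], b[j]) (hypothetical helper).

    d must accept two broadcastable tensors and reduce the last dimension.
    """
    # a[:, None, :] has shape (rows_a, 1, n); b[None, :, :] has shape (1, rows_b, n).
    return d(a[:, None, :], b[None, :, :])

def cosine_distance(x, y, eps=1e-8):
    # 1 - normalized inner product; eps guards against zero-norm rows.
    dot = (x * y).sum(dim=-1)
    norms = x.norm(dim=-1) * y.norm(dim=-1)
    return 1 - dot / norms.clamp(min=eps)

def manhattan(x, y):
    # L1 distance, used here only as a second example metric.
    return (x - y).abs().sum(dim=-1)

torch.manual_seed(0)
a, b = torch.randn(4, 3), torch.randn(5, 3)  # illustrative sizes

cos_m = pairwise(a, b, cosine_distance)
l1_m = pairwise(a, b, manhattan)
print(cos_m.shape, l1_m.shape)

# Cross-check the L1 result against the built-in p-norm implementation.
print(torch.allclose(l1_m, torch.cdist(a, b, p=1), atol=1e-5))
```

Note that metrics written this way (elementwise product or difference followed by a reduction) build an intermediate of shape (rows_a, rows_b, n), so memory grows with the product of the row counts. For large inputs, a metric-specific formulation such as a matrix product for the cosine case is likely to be cheaper.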