I'm trying to implement the InfoNCE loss from this paper with my own image dataset. I'm following the implementation from this repo and am using the following code to train my model in a self-supervised manner:
import torch
import torch.nn.functional as F

def info_nce_loss(self, features):
    labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(self.args.device)

    features = F.normalize(features, dim=1)

    similarity_matrix = torch.matmul(features, features.T)
    # assert similarity_matrix.shape == (
    #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
    # assert similarity_matrix.shape == labels.shape

    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    # assert similarity_matrix.shape == labels.shape

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    logits = logits / self.args.temperature
    return logits, labels
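For context, the returned logits and labels are fed to a standard cross-entropy criterion, which (as far as I can tell) is how the reference repo consumes them:

criterion = torch.nn.CrossEntropyLoss()

# inside the training loop: features are the concatenated projections
# of the n_views augmented versions of the batch
logits, labels = self.info_nce_loss(features)
loss = criterion(logits, labels)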
I was using a batch_size of 32 with the above loss function, and everything was working fine. But when I change the batch_size to any other number, for instance 256, I get the following error:
The shape of the mask [512, 512] at index 0 does not match the shape of the indexed tensor [2, 2] at index 0.
The error originates at this line:
labels = labels[~mask].view(labels.shape[0], -1)
I tried resizing my images, but that didn't help either. Any idea what the issue could be here?
I found the solution. It was previously reported here and recently here. However, the code was still not updated. It seems this line:

labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)

is dependent on batch_size and was causing the issue. The above line needs to be replaced with either:

labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)

OR:

labels = torch.cat([torch.arange(features.shape[0]//self.args.n_views) for i in range(self.args.n_views)], dim=0)

The two are equivalent for the default n_views = 2; the second form also generalizes to other values of n_views. Either way, the labels are now sized from the actual features tensor rather than the configured batch_size, so a batch whose size differs from self.args.batch_size (e.g. the last, incomplete batch of an epoch) no longer breaks the indexing.
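To see why, here is a minimal standalone sketch (the batch sizes are hypothetical, chosen to match the shapes in the error message above): the labels and mask are sized from the configured batch_size, while the similarity matrix follows the features that are actually passed in.

import torch

batch_size, n_views = 256, 2

# labels and mask are built from the configured batch_size ...
labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0)
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
mask = torch.eye(labels.shape[0], dtype=torch.bool)
print(mask.shape)  # torch.Size([512, 512])

# ... while the similarity matrix is sized from the actual features, e.g. a
# hypothetical final batch containing a single image (n_views * 1 rows)
features = torch.randn(1 * n_views, 128)
similarity_matrix = torch.matmul(features, features.T)
print(similarity_matrix.shape)  # torch.Size([2, 2])

# indexing a [2, 2] tensor with the [512, 512] mask raises the reported error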
Based on that change, the correct loss function would be:
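(This is the original function with only the labels line swapped for the second, n_views-general replacement; the rest is unchanged.)

def info_nce_loss(self, features):
    labels = torch.cat([torch.arange(features.shape[0] // self.args.n_views) for i in range(self.args.n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(self.args.device)

    features = F.normalize(features, dim=1)

    similarity_matrix = torch.matmul(features, features.T)

    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    logits = logits / self.args.temperature
    return logits, labels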