Shape mismatch when implementing InfoNCE loss

I'm trying to implement the InfoNCE loss from this paper on my own image dataset. I'm following the implementation from this repo and using the following code:

def info_nce_loss(self, features):

        labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels.to(self.args.device)

        features = F.normalize(features, dim=1)

        similarity_matrix = torch.matmul(features, features.T)
        # assert similarity_matrix.shape == (
        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
        # assert similarity_matrix.shape == labels.shape

        # discard the main diagonal from both: labels and similarities matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        # assert similarity_matrix.shape == labels.shape

        # select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

        # select only the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

        logits = logits / self.args.temperature
        return logits, labels

to train my model in a self-supervised manner. With a batch_size of 32, this loss function worked fine. But when I change batch_size to any other value, for instance 256, I get the following error:

The shape of the mask [512, 512] at index 0 does not match the shape of the indexed tensor [2, 2] at index 0. 

The error originates at this line:

labels = labels[~mask].view(labels.shape[0], -1)

I tried resizing my images but that didn't help either. Any idea on what could be the issue here?

1 Answer

I found the solution. The issue was previously reported here, and again more recently here, but the code in the repo has still not been updated. It seems this line:

labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)

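depends on self.args.batch_size rather than on the number of rows actually present in features, so labels and similarity_matrix can end up with different shapes, which was causing the issue. The above line needs to be replaced with either:

labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)

or, more generally (the first form assumes n_views is 2):

labels = torch.cat([torch.arange(features.shape[0]//self.args.n_views) for i in range(self.args.n_views)], dim=0)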
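To illustrate one way the mismatch can arise, here is a minimal sketch outside the class. The helper build_label_matrix, the variable configured_batch_size, and the random features tensor are hypothetical names used only for illustration; they stand in for the label construction, self.args.batch_size, and the model output respectively.

```python
import torch


def build_label_matrix(num_samples, n_views):
    """Pairwise 'same sample' indicator for n_views stacked copies of num_samples items."""
    ids = torch.cat([torch.arange(num_samples) for _ in range(n_views)], dim=0)
    return (ids.unsqueeze(0) == ids.unsqueeze(1)).float()


n_views = 2
configured_batch_size = 32              # hypothetical stand-in for self.args.batch_size
features = torch.randn(2 * 256, 128)    # 256 samples, 2 views each (assumed sizes)

similarity = features @ features.T      # shape follows the features: (512, 512)

# Labels sized from the configured value ignore how many rows features really has.
stale = build_label_matrix(configured_batch_size, n_views)
# Labels sized from features always agree with the similarity matrix.
fixed = build_label_matrix(features.shape[0] // n_views, n_views)

shapes_match_stale = stale.shape == similarity.shape
shapes_match_fixed = fixed.shape == similarity.shape

# A boolean mask built from the stale labels cannot index the similarity matrix.
mask = torch.eye(stale.shape[0], dtype=torch.bool)
try:
    similarity[~mask]
    raised = False
except (IndexError, RuntimeError):
    raised = True
```

Sizing the labels from features.shape[0] means the two matrices agree by construction, whatever batch size the data loader actually delivers.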

Based on that change, the correct loss function would be:

import torch
import torch.nn.functional as F

def info_nce_loss(self, features):
    # Size the labels from the features actually passed in,
    # not from self.args.batch_size.
    labels = torch.cat([torch.arange(features.shape[0] // self.args.n_views) for i in range(self.args.n_views)], dim=0)
    # labels[i, j] is 1 when rows i and j are views of the same sample
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(self.args.device)

    features = F.normalize(features, dim=1)

    # cosine similarity between every pair of rows
    similarity_matrix = torch.matmul(features, features.T)
    assert similarity_matrix.shape == labels.shape

    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    # positives go in the first column(s), so the target class is index 0
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    logits = logits / self.args.temperature
    return logits, labels
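As a quick sanity check, the same logic can be written as a standalone function and run with several batch sizes. This is only a sketch under assumptions: the function name info_nce_logits, the default temperature of 0.07, the feature dimension of 128, and the random inputs are all illustrative and not taken from the original code.

```python
import torch
import torch.nn.functional as F


def info_nce_logits(features, n_views=2, temperature=0.07):
    """Standalone variant of the loss above; n_views and temperature are assumed defaults."""
    n = features.shape[0]
    assert n % n_views == 0, "features must hold n_views rows per sample"

    # Sample ids repeated once per view, sized from the features themselves.
    ids = torch.cat([torch.arange(n // n_views) for _ in range(n_views)], dim=0)
    labels = (ids.unsqueeze(0) == ids.unsqueeze(1)).to(features.device)

    features = F.normalize(features, dim=1)
    sim = features @ features.T

    # Drop the diagonal (self-similarity) from both matrices.
    mask = torch.eye(n, dtype=torch.bool, device=features.device)
    labels = labels[~mask].view(n, -1)
    sim = sim[~mask].view(n, -1)

    positives = sim[labels].view(n, -1)
    negatives = sim[~labels].view(n, -1)

    # Positives first, so the target class for every row is index 0.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(n, dtype=torch.long, device=features.device)
    return logits, targets


shapes = {}
losses = {}
for batch_size in (32, 100, 256):
    feats = torch.randn(2 * batch_size, 128)   # random stand-in for encoder output
    logits, targets = info_nce_logits(feats)
    shapes[batch_size] = tuple(logits.shape)
    losses[batch_size] = F.cross_entropy(logits, targets)
```

With two views, each row of logits holds one positive followed by all remaining rows as negatives, so a batch of B samples should give logits of shape (2B, 2B - 1) regardless of B.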