How to determine accuracy with triplet loss in a convolutional neural network


A triplet network (inspired by the Siamese network) is composed of three instances of the same feed-forward network with shared parameters. When fed three samples, it outputs two intermediate values: the L2 (Euclidean) distances from the embedded representation of one reference input to the embedded representations of the other two inputs.

I'm feeding the network with triplets of images: x = the anchor image, a reference image; x+ = the positive image, an image containing the same object as x (more precisely, an image of the same class as x); and x- = the negative image, an image of a different class than x.


I'm using the triplet loss cost function described here.
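
Roughly, it looks like this (a minimal sketch of the standard margin formulation; the margin value and PyTorch helpers here are my own choices, not necessarily the exact ones from the linked description):

import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=0.2):
    # L2 distances from the anchor embedding to the positive and negative embeddings
    d_pos = F.pairwise_distance(anchor, positive, p=2)
    d_neg = F.pairwise_distance(anchor, negative, p=2)
    # Hinge on the margin: the negative should be at least `margin` farther away than the positive
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()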

How do I determine the network's accuracy?


There are 2 answers below

BEST ANSWER

I am assuming that you are working on image retrieval or a similar task.

You should first generate some triplets, either randomly or using some hard (or semi-hard) negative mining method. Then split your triplets into training and validation sets.
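
As a rough illustration, random triplet generation from integer class labels could look something like this (a sketch with my own helper name, not code from any particular library):

import random
from collections import defaultdict

def make_random_triplets(labels, n_triplets):
    # Returns a list of (anchor_idx, positive_idx, negative_idx) index triplets
    by_class = defaultdict(list)
    for idx, lbl in enumerate(labels):
        by_class[lbl].append(idx)
    classes = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    triplets = []
    for _ in range(n_triplets):
        pos_class = random.choice(classes)
        neg_class = random.choice([c for c in by_class if c != pos_class])
        anchor, positive = random.sample(by_class[pos_class], 2)
        negative = random.choice(by_class[neg_class])
        triplets.append((anchor, positive, negative))
    return triplets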

If you do it this way, then you can define your validation accuracy as the proportion of validation triplets in which the feature distance between the anchor and the positive is less than that between the anchor and the negative. You can see an example here, written in PyTorch.
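
Concretely, that accuracy could be computed along these lines (a sketch, not the code from the linked example; it assumes model maps a batch of images to embeddings and that the validation loader yields (anchor, positive, negative) batches):

import torch

@torch.no_grad()
def triplet_accuracy(model, val_loader, device='cpu'):
    model.eval()
    correct, total = 0, 0
    for anchor, positive, negative in val_loader:
        a = model(anchor.to(device))
        p = model(positive.to(device))
        n = model(negative.to(device))
        d_ap = torch.norm(a - p, dim=1)  # anchor-positive distance
        d_an = torch.norm(a - n, dim=1)  # anchor-negative distance
        correct += (d_ap < d_an).sum().item()
        total += a.size(0)
    return correct / total  # fraction of validation triplets ranked correctly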

Alternatively, you can measure performance directly in terms of your final testing metric. For image retrieval, for example, model performance on the test set is typically measured with mean average precision (mAP). If you use this metric, you should first define some queries on your validation set and their corresponding ground-truth images.
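
If you go the retrieval route, mean average precision over your queries could be sketched like this (my own helpers; they assume you already have, per query, a ranked list of retrieved item ids and the set of relevant ground-truth ids):

def average_precision(ranked_ids, relevant_ids):
    # AP for one query: mean of precision@k over the ranks k where a relevant item appears
    relevant_ids = set(relevant_ids)
    hits, precisions = 0, []
    for k, item in enumerate(ranked_ids, start=1):
        if item in relevant_ids:
            hits += 1
            precisions.append(hits / k)
    return sum(precisions) / max(len(relevant_ids), 1)

def mean_average_precision(all_ranked, all_relevant):
    # mAP: average of the per-query average precision values
    aps = [average_precision(r, g) for r, g in zip(all_ranked, all_relevant)]
    return sum(aps) / len(aps)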

Either of the above two metrics is fine. Choose whichever fits your case.


I am performing a similar task, using triplet loss for classification. Here is how I combined the triplet loss with a classifier. First, train your model with the standard triplet loss function for N epochs (a sketch of this training stage follows the model definition below). Once you are sure that this model (we shall refer to it as the embedding generator) is trained, save its weights, as we shall be using them ahead. Let's say that your embedding generator is defined as:

import torch
import torch.nn as nn

class EmbeddingNetwork(nn.Module):
    def __init__(self):
        super(EmbeddingNetwork, self).__init__()
        # First convolutional block: 7x7 conv, batch norm, leaky ReLU, max pooling
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, (7, 7), stride=(2, 2), padding=(3, 3)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.001),
            nn.MaxPool2d((3, 3), 2, padding=(1, 1))
        )
        # Second convolutional block: 1x1 bottleneck followed by a 3x3 conv
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, (1, 1), stride=(1, 1)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.001),
            nn.Conv2d(64, 192, (3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.001),
            nn.MaxPool2d((3, 3), 2, padding=(1, 1))
        )
        # Fully connected head mapping the flattened feature map to a 128-d embedding.
        # The in_features must match the flattened conv output: 192 channels x 7 x 7 here
        # (e.g. for 56x56 single-channel inputs); adjust it to your own input resolution.
        self.fullyConnected = nn.Sequential(
            nn.Linear(192 * 7 * 7, 32 * 128),
            nn.BatchNorm1d(32 * 128),
            nn.LeakyReLU(0.001),
            nn.Linear(32 * 128, 128)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, start_dim=1)  # flatten to (batch, features) before the linear layers
        x = self.fullyConnected(x)
        # L2-normalise so the embeddings lie on the unit hypersphere
        return torch.nn.functional.normalize(x, p=2, dim=-1)
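
For completeness, the triplet-loss training stage mentioned above might look roughly like this (a sketch using PyTorch's built-in nn.TripletMarginLoss; the margin, optimiser settings, num_epochs and triplet_loader are placeholders you would replace with your own):

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embedder = EmbeddingNetwork().to(device)
criterion = nn.TripletMarginLoss(margin=0.2)
optimizer = torch.optim.Adam(embedder.parameters(), lr=1e-3)

for epoch in range(num_epochs):                         # num_epochs: train for N epochs
    for anchor, positive, negative in triplet_loader:   # loader yielding image triplets
        optimizer.zero_grad()
        loss = criterion(embedder(anchor.to(device)),
                         embedder(positive.to(device)),
                         embedder(negative.to(device)))
        loss.backward()
        optimizer.step()

# Save the trained embedding generator; the classifier stage below reloads these weights
torch.save(embedder.state_dict(), 'embeddingNetwork.pt')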

Now we shall use this embedding generator inside a classifier, load the weights we saved before into this part of the network, and then freeze it so that classifier training does not interfere with the triplet model. This can be done as:

import torch.nn as nn
import torch.nn.functional as F

class classifierNet(nn.Module):
    def __init__(self, EmbeddingNet):
        super(classifierNet, self).__init__()
        # Pretrained (and soon to be frozen) embedding generator
        self.embeddingLayer = EmbeddingNet
        # Linear classifier on top of the 128-d embedding (62 classes here)
        self.classifierLayer = nn.Linear(128, 62)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.dropout(self.embeddingLayer(x))
        x = self.classifierLayer(x)
        return F.log_softmax(x, dim=1)

Now we shall load the weights we saved before and freeze them using:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embeddingNetwork = EmbeddingNetwork().to(device)
embeddingNetwork.load_state_dict(torch.load('embeddingNetwork.pt'))
for param in embeddingNetwork.parameters():  # freeze so classifier training leaves the embedder untouched
    param.requires_grad = False
classifierNetwork = classifierNet(embeddingNetwork).to(device)

Now train this classifier network with a standard classification loss such as cross-entropy. Note that because forward already returns log_softmax outputs, nn.NLLLoss is the criterion that matches them (nn.CrossEntropyLoss expects raw logits and would apply log-softmax a second time).
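
A rough training loop for that stage could look like this (the optimiser settings, num_epochs and train_loader are placeholders):

import torch
import torch.nn as nn

criterion = nn.NLLLoss()  # pairs with the log_softmax output of classifierNet
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, classifierNetwork.parameters()), lr=1e-3)

for epoch in range(num_epochs):
    classifierNetwork.train()
    classifierNetwork.embeddingLayer.eval()  # keep the frozen embedder's BatchNorm statistics fixed
    for images, targets in train_loader:     # loader yielding (image, class label) batches
        optimizer.zero_grad()
        log_probs = classifierNetwork(images.to(device))
        loss = criterion(log_probs, targets.to(device))
        loss.backward()
        optimizer.step()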