Triplet Network, loss function and equal distances


I'm currently implementing a triplet network to recognise whether two images depict the same 3D model, but I have a problem with the results: the anchor-positive distance is always equal to the anchor-negative distance.

Here is the code of my loss function:

    def triplet_loss(self):
        # Squared Euclidean distances: anchor-positive and anchor-negative.
        self.d_pos = tf.reduce_sum(tf.square(self.o1 - self.o2), axis=-1)
        self.d_neg = tf.reduce_sum(tf.square(self.o1 - self.o3), axis=-1)

        # Hinge on the margin, averaged over the batch.
        loss = tf.maximum(0.0, self.margin + (self.d_pos - self.d_neg))
        loss = tf.reduce_mean(loss)

        return loss
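
As a sanity check of the formula, here is a minimal NumPy-only sketch of the same computation. It is not my actual TensorFlow graph: the function name `triplet_loss_np`, the embeddings, and the margin of 0.5 are made up for illustration. It only shows that when the two distances coincide, the hinge reduces to the margin.

```python
import numpy as np

def triplet_loss_np(anchor, positive, negative, margin):
    """NumPy restatement of the loss above (hypothetical helper, for checking only)."""
    # Squared Euclidean distances along the embedding axis.
    d_pos = np.sum(np.square(anchor - positive), axis=-1)
    d_neg = np.sum(np.square(anchor - negative), axis=-1)
    # Hinge per triplet, then mean over the batch.
    per_triplet = np.maximum(0.0, margin + (d_pos - d_neg))
    return per_triplet.mean(), d_pos, d_neg

# Made-up embeddings: positive and negative lie at the same distance from the anchor.
anchor = np.zeros((2, 3))
positive = np.array([[0.2, 0.0, 0.0], [0.0, 0.2, 0.0]])
negative = np.array([[0.0, 0.0, 0.2], [0.2, 0.0, 0.0]])

loss, d_pos, d_neg = triplet_loss_np(anchor, positive, negative, margin=0.5)
print(loss, d_pos, d_neg)  # the loss matches the margin because d_pos == d_neg
```

So with equal distances the loss value carries no information beyond the margin itself.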

Here o1, o2 and o3 are the outputs of convolutional networks with shared weights, and each output is batch-normalized:

    output = tf.layers.batch_normalization(
        inputs=output, axis=-1, momentum=0.9, epsilon=0.0001,
        center=True, scale=True, name='batch_3_norm')

The first results are the following:

epoch 0:    batch:0   loss 0.0000199945   dneg : 0.079995   dpos; 0.079995 

epoch 0:    batch:1   loss 0.0000201295   dneg : 0.092946   dpos; 0.092946

epoch 0:    batch:2   loss 0.0000205572   dneg : 0.110583   dpos; 0.110583 

epoch 0:    batch:3   loss 0.0000216728   dneg : 0.122692   dpos; 0.122693 

epoch 0:    batch:4   loss 0.0000202223   dneg : 0.111207   dpos; 0.111207 

epoch 0:    batch:5   loss 0.0000200346   dneg : 0.105684   dpos; 0.105684 
############### Test set : batch:5   loss 0.000 

epoch 1:    batch:0   loss 0.0000207106   dneg : 0.105736   dpos; 0.105737 

epoch 1:    batch:1   loss 0.0000200992   dneg : 0.107299   dpos; 0.107299 

epoch 1:    batch:2   loss 0.0000207007   dneg : 0.111667   dpos; 0.111667 

epoch 1:    batch:3   loss 0.0000201932   dneg : 0.109080   dpos; 0.109081 

epoch 1:    batch:4   loss 0.0000206707   dneg : 0.111295   dpos; 0.111295 

(dneg and dpos are the anchor-negative and anchor-positive distances, respectively.)

So I have several questions:

  • How should I tune the margin? Since the difference between the two distances is so small, do I have to use a very small margin?

  • Because the two distances are equal, the loss is equal to the margin. How can I avoid this?

  • How do I measure the accuracy of a triplet network? For example, for a batch of size 100, can we count the number of negative examples whose distance to the anchor is greater than the anchor-positive distance plus the margin?
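
To make the last question concrete, here is a minimal sketch of the count I have in mind, written with NumPy on already-computed distances. The function name `triplet_accuracy` and the sample values are hypothetical, and I am not sure this is the standard way to evaluate a triplet network.

```python
import numpy as np

def triplet_accuracy(d_pos, d_neg, margin=0.0):
    """Fraction of triplets where the negative is farther from the anchor
    than the positive by more than `margin` (hypothetical metric)."""
    d_pos = np.asarray(d_pos, dtype=float)
    d_neg = np.asarray(d_neg, dtype=float)
    return float(np.mean(d_neg > d_pos + margin))

# Made-up distances for a batch of four triplets.
d_pos = [0.10, 0.30, 0.20, 0.50]
d_neg = [0.40, 0.35, 0.20, 0.90]

acc_no_margin = triplet_accuracy(d_pos, d_neg)            # 3 of 4 triplets are ordered correctly
acc_margin = triplet_accuracy(d_pos, d_neg, margin=0.1)   # only 2 of 4 clear the margin
print(acc_no_margin, acc_margin)
```

With the distances from my log above, where dneg and dpos are equal to within the last printed digit, this count would be close to zero for any positive margin.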

Thanks a lot for your answers!
