Keras: triplet loss with positive and negative sample within batch


I am trying to refactor my Keras code to use 'Batch Hard' sampling for the triplets, as proposed in https://arxiv.org/pdf/1703.07737.pdf:

" the core idea is to form batches by randomly sampling P classes (person identities), and then randomly sampling K images of each class (person), thus resulting in a batch of PK images. Now, for each sample a in the batch, we can select the hardest positive and the hardest negative samples within the batch when forming the triplets for computing the loss, which we call Batch Hard"

So at the moment I have a Python generator (for use with model.fit_generator in Keras) that produces batches on the CPU, while the actual forward and backward passes through the model run on the GPU.
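
The generator follows roughly this PK sampling pattern (a simplified sketch; `images_by_class` is a hypothetical dict mapping a class ID to its image arrays):

    import numpy as np

    # Simplified PK batch generator: P classes, K images each, e.g. 16 * 4 = 64.
    # Assumes every class has at least K images available.
    def pk_batch_generator(images_by_class, P=16, K=4):
        class_ids = list(images_by_class.keys())
        while True:
            batch_x, batch_y = [], []
            for c in np.random.choice(class_ids, size=P, replace=False):
                picks = np.random.choice(len(images_by_class[c]), size=K, replace=False)
                batch_x.extend(images_by_class[c][i] for i in picks)
                batch_y.extend([c] * K)
            yield np.stack(batch_x), np.array(batch_y)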

However, how can this be made to fit the 'Batch Hard' method? The generator samples 64 images, from which 64 triplets should be formed. First, a forward pass through the current model is required to obtain the 64 embeddings:

    from keras.models import Model

    embedding_model = Model(inputs=input_image, outputs=embedding)

But then the hardest positive and the hardest negative have to be selected from the 64 embeddings to form the triplets, after which the loss can be computed:

    from keras.layers import Input

    anchor = Input(input_shape, name='anchor')
    positive = Input(input_shape, name='positive')
    negative = Input(input_shape, name='negative')

    f_anchor = embedding_model(anchor)
    f_pos = embedding_model(positive)
    f_neg = embedding_model(negative)

    triplet_model = Model(inputs=[anchor, positive, negative],
                          outputs=[f_anchor, f_pos, f_neg])

And this triplet_model can be trained by defining a triplet loss function, e.g. the standard margin formulation sketched below. However, is it possible to combine fit_generator with the 'Batch Hard' method in Keras? Or how can I get access to the embeddings of the other samples in the batch?
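
For reference, the loss I have in mind is the standard margin formulation. One way to feed it to a single loss function is to merge the three embeddings into one output (a sketch only; the merged output, emb_dim and margin values are my assumptions):

    from keras import backend as K
    from keras.layers import concatenate

    # Hypothetical restructuring: merge the three embeddings into one
    # output so a single loss function sees them together.
    merged = concatenate([f_anchor, f_pos, f_neg], axis=-1)
    triplet_model = Model(inputs=[anchor, positive, negative], outputs=merged)

    def triplet_loss(y_true, y_pred, margin=0.2, emb_dim=128):
        # y_true is a dummy target; y_pred has shape (batch, 3 * emb_dim)
        f_a = y_pred[:, 0:emb_dim]
        f_p = y_pred[:, emb_dim:2 * emb_dim]
        f_n = y_pred[:, 2 * emb_dim:]
        d_ap = K.sum(K.square(f_a - f_p), axis=-1)
        d_an = K.sum(K.square(f_a - f_n), axis=-1)
        return K.mean(K.maximum(d_ap - d_an + margin, 0.0))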

Edit: With keras.layers.Lambda I can define a custom layer that creates triplets, with input (batch_size, height, width, 3) and output (batch_size, 3, height, width, 3), but I also need access to the IDs somewhere. Is this possible within the layer?
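
An alternative I am considering instead of the Lambda layer: keep embedding_model as the trained model and do the batch-hard mining inside a custom loss, smuggling the class IDs in through y_true. A rough, untested sketch (the y_true packing is my assumption):

    import keras.backend as K

    # Rough sketch: batch-hard mining inside the loss function.
    # Assumption: the generator packs the integer class ID of each image
    # into the first column of y_true; y_pred are the embeddings.
    def batch_hard_triplet_loss(y_true, y_pred, margin=0.2):
        labels = K.cast(y_true[:, 0], 'int32')
        # pairwise squared Euclidean distances between all embeddings
        dots = K.dot(y_pred, K.transpose(y_pred))
        sq = K.sum(K.square(y_pred), axis=1, keepdims=True)
        dist = K.maximum(sq - 2.0 * dots + K.transpose(sq), 0.0)
        same = K.cast(K.equal(K.expand_dims(labels, 0),
                              K.expand_dims(labels, 1)), 'float32')
        # hardest positive: largest distance among same-ID pairs
        hardest_pos = K.max(dist * same, axis=1)
        # hardest negative: smallest distance among different-ID pairs
        # (same-ID entries are pushed up so they are never selected)
        hardest_neg = K.min(dist + (K.max(dist) + 1.0) * same, axis=1)
        return K.mean(K.maximum(hardest_pos - hardest_neg + margin, 0.0))

This would let me keep fit_generator with a plain (images, IDs) generator, but I am not sure whether abusing y_true like this is the intended way in Keras.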
