Select a random non-zero element from each row of a 3D tensor in non-eager TensorFlow


I am trying to implement the triplet loss with random negative triplet selection. Right now I have a tensor of shape (batch_size, batch_size, batch_size) where element (i,j,k) is equal to dist(i,j) - dist(i,k) + margin (i is the anchor, j is a positive, k is a negative).

I zero out all invalid elements and take tf.maximum(tensor, 0.).
Now, for each pair (i,j), I want to randomly select one non-zero element along k, if one exists, and calculate the mean of all the selected elements. Eager execution has to stay disabled, so the selection must be expressed with tensor ops only, without iterating over elements.

Right now my code looks like this:

def random_negative_triplet_loss(labels, embeddings):

    margin = 1.
    # Get the pairwise distance matrix
    pairwise_dist = _pairwise_distances(embeddings)

    # shape (batch_size, batch_size, 1)
    anchor_positive_dist = tf.expand_dims(pairwise_dist, 2)
    assert anchor_positive_dist.shape[2] == 1, "{}".format(anchor_positive_dist.shape)
    # shape (batch_size, 1, batch_size)
    anchor_negative_dist = tf.expand_dims(pairwise_dist, 1)
    assert anchor_negative_dist.shape[1] == 1, "{}".format(anchor_negative_dist.shape)

    # Compute a 3D tensor of size (batch_size, batch_size, batch_size)
    # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
    # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
    # and the 2nd (batch_size, 1, batch_size)
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
    # Put to zero the invalid triplets
    # (where label(a) != label(p) or label(n) == label(a) or a == p)
    mask = _get_triplet_mask(labels)

    mask = tf.to_float(mask)
    triplet_loss = tf.multiply(mask, triplet_loss)
    # Remove negative losses (i.e. the easy triplets)
    triplet_loss = tf.maximum(triplet_loss, 0.0)
    num_classes = 5
    the_loss = 0
    the_count = 0
    num_valid = tf.reduce_sum(mask, axis=2)
    valid_count = tf.reduce_sum(tf.to_int32(tf.greater(num_valid, 1e-16)))
    sampler = tf.distributions.Uniform(0., tf.to_float(50) - 1e-3)

I assume that the randomness can be achieved with tf.distributions.Uniform, but since each pair (i,j) has a different number of valid indices k, I don't know how to apply it.
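One way to handle the varying number of valid k per pair, sketched here in NumPy rather than TensorFlow so it is easy to run: give every entry an independent uniform random score, zero the scores of the zero entries, and take the argmax along k. Every non-zero entry in a row then has the same chance of winning, however many there are. The function name mean_of_random_nonzero and the rng argument are made up for this illustration. A graph-mode version would presumably swap in tf.random.uniform (tf.random_uniform in TF 1.x), tf.argmax, and tf.one_hot followed by tf.reduce_sum for the gather, but that translation is an assumption and has not been tested here.

```python
import numpy as np


def mean_of_random_nonzero(triplet_loss, rng):
    """For every (i, j) row, pick one non-zero entry along k uniformly at random.

    Returns the mean of the picked values over rows that have at least one
    non-zero entry (0.0 if no row has one), plus the picked k indices.
    """
    nonzero = triplet_loss > 0.0                      # (B, B, K) boolean mask

    # Scores in (0, 1] for non-zero entries and exactly 0 for zero entries,
    # so a zero entry can never win the argmax when a non-zero one exists.
    scores = (1.0 - rng.random(triplet_loss.shape)) * nonzero
    picked_k = scores.argmax(axis=2)                  # (B, B)

    # Gather the loss value at the picked index of each row.
    picked = np.take_along_axis(
        triplet_loss, picked_k[..., None], axis=2
    )[..., 0]                                         # (B, B)

    # Rows without any non-zero entry must not count towards the mean.
    has_any = nonzero.any(axis=2)
    count = has_any.sum()
    total = (picked * has_any).sum()
    return total / max(count, 1), picked_k


# Small example: two rows with exactly one non-zero entry each.
rng = np.random.default_rng(0)
loss = np.zeros((2, 2, 3))
loss[0, 0, 1] = 2.0
loss[1, 1, 2] = 4.0
mean, picked_k = mean_of_random_nonzero(loss, rng)
print(mean)  # 3.0, the mean of 2.0 and 4.0
```

Nothing here loops over elements, so the same sequence of ops should be expressible in a static graph. Ties between random scores have probability zero in practice, which is why the argmax can be treated as a uniform draw among the non-zero entries.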
