Variable size multi-label candidate sampling in tensorflow?


nce_loss() asks for a static int value for num_true. That works well for problems where we have the same number of labels per training example and know it in advance.

When labels have a variable shape [None], and are batched and/or bucketed by length with .padded_batch() + .group_by_window(), it is necessary to provide a variable-size num_true in order to accommodate all training examples. This is currently unsupported to my knowledge (correct me if I'm wrong).
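To make the batching setup concrete, here is a pure-Python sketch of what bucketing by length plus padding does (the `bucket_and_pad` helper and the `PAD` constant are my own illustration, not TensorFlow API; tf.data's group_by_window + padded_batch perform the equivalent on tensors):

```python
from collections import defaultdict

PAD = 0  # assumed padding label for illustration


def bucket_and_pad(label_lists, bucket_size):
    """Group variable-length label lists into buckets whose width is
    the smallest multiple of bucket_size that fits, then pad each list
    with PAD so every list in a bucket has that bucket's width."""
    buckets = defaultdict(list)
    for labels in label_lists:
        width = ((len(labels) + bucket_size - 1) // bucket_size) * bucket_size
        buckets[width].append(labels + [PAD] * (width - len(labels)))
    return dict(buckets)


# Every list inside a bucket now shares one length, so each bucket
# can be stacked into a dense [batch, num_true] tensor.
batches = bucket_and_pad([[3, 7], [5], [2, 4, 6, 8, 9]], bucket_size=3)
```

The catch the question points at: even after this, num_true differs between buckets, while nce_loss wants one fixed integer.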

In other words, suppose we have either a dataset of images with an arbitrary number of labels per image (dog, cat, duck, etc.) or a text dataset with multiple classes per sentence (class_1, class_2, ..., class_n). Classes are NOT mutually exclusive and can vary in number between examples. Since the number of possible labels can be huge (10k-100k), is there a way to use a sampling loss to improve performance (compared with sigmoid_cross_entropy)?

Is there a proper way to do this or any other workarounds?

nce_loss = tf.nn.nce_loss(
    weights=nce_weights,
    biases=nce_biases,
    labels=labels,
    inputs=inputs,
    num_sampled=num_sampled,
    # Something like `num_true=tf.shape(labels)[-1]` instead of
    # `num_true=const_int` would be preferable here
    num_classes=self.num_classes)

1 Answer


I see two issues: 1) using NCE with different numbers of true values; 2) classes that are NOT mutually exclusive.

On the first issue, as @michal said, there is an expectation that this functionality will be added in the future. I have tried almost the same thing: using labels with shape=(None, None), i.e., a true_values dimension of None. The sampled_values parameter has the same problem: the number of true_values must be a fixed integer. The recommended workaround is to use a class (0 is the best choice) representing <PAD> and pad out to the fixed number of true_values. In my case, 0 is a special token that represents <PAD>. Part of the code is here:

# Pad the label list with the <PAD> class (0) up to the fixed
# num_true size of window_size * 2, then sort.
assert len(labels) <= (window_size * 2)
zeros = ((window_size * 2) - len(labels)) * [0]
labels = labels + zeros
labels.sort()
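The same steps wrapped in a small reusable helper, for clarity (a sketch; the name `pad_true_labels` is mine, and it assumes, as in the snippet above, that window_size * 2 is the fixed num_true):

```python
PAD = 0  # class 0 is reserved for <PAD>, as described above


def pad_true_labels(labels, window_size):
    """Pad a variable-length label list with the PAD class so every
    example has exactly window_size * 2 true labels, then sort the
    result in increasing index order."""
    num_true = window_size * 2
    assert len(labels) <= num_true
    padded = labels + [PAD] * (num_true - len(labels))
    padded.sort()
    return padded


pad_true_labels([42, 7, 19], window_size=2)
```

Every example now yields a list of exactly num_true labels, so the batch stacks into the fixed [batch_size, num_true] shape that nce_loss expects.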

I sorted the labels because of another recommendation in the docs:

Note: By default this uses a log-uniform (Zipfian) distribution for sampling, so your labels must be sorted in order of decreasing frequency to achieve good results.

In my case, the special tokens and more frequent words have lower indexes, while less frequent words have higher indexes. I included all label classes associated with the input at the same time and padded with zeros up to the true_values number. Of course, you must ignore the 0 class at the end.
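Ignoring the 0 class at the end can be as simple as filtering it out before evaluation (a minimal sketch, assuming class 0 is only ever used for <PAD> and never as a real label):

```python
PAD = 0  # the padding class to ignore, as noted above


def strip_pad(padded_labels):
    """Remove the padding class from an example's label list so the
    <PAD> entries do not count as real labels during evaluation."""
    return [label for label in padded_labels if label != PAD]


strip_pad([0, 0, 7, 19, 42])
```

During training the padded zeros still participate in the loss as "true" classes, which is the price of this workaround; the filtering only cleans up metrics and predictions.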