`tf.nn.nce_loss()` asks for a static int value for `num_true`. That works well for problems where we have the same number of labels per training example and know it in advance.
When labels have a variable shape `[None]` and are batched and/or bucketed with `.padded_batch()` + `.group_by_window()`, it is necessary to provide a variable-size `num_true` to accommodate all training examples. To my knowledge this is currently unsupported (correct me if I'm wrong).
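For concreteness, here is a minimal sketch of such a pipeline (the feature sizes and label ids are made up), where `.padded_batch()` yields label batches of varying width:

```python
import tensorflow as tf

def gen():
    # Each example has a variable number of label ids.
    yield [3.0, 1.0], [5, 12]        # two labels
    yield [0.5, 2.0], [7]            # one label
    yield [1.5, 0.2], [2, 9, 31]     # three labels

dataset = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32, tf.int64),
    output_shapes=([2], [None]))     # labels have variable shape [None]

# Pads labels with 0 to the longest example in each batch, so the label
# tensor is [batch, max_len] with max_len varying from batch to batch.
dataset = dataset.padded_batch(2, padded_shapes=([2], [None]))
```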
In other words, suppose we have either a dataset of images with an arbitrary number of labels per image (dog, cat, duck, etc.) or a text dataset with multiple classes per sentence (class_1, class_2, ..., class_n). Classes are NOT mutually exclusive, and their number can vary between examples.
But since the number of possible labels can be huge (10k-100k), is there a way to use a sampled loss to improve performance (compared with `sigmoid_cross_entropy`)? Is there a proper way to do this, or any other workarounds?
```python
nce_loss = tf.nn.nce_loss(
    weights=nce_weights,
    biases=nce_biases,
    labels=labels,
    inputs=inputs,
    num_sampled=num_sampled,
    # Something like `num_true=tf.shape(labels)[-1]` instead of
    # `num_true=const_int` would be preferable here
    num_classes=self.num_classes)
```
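For reference, the dense alternative I am comparing against scores every class for every example. A rough sketch (here `multi_hot_labels` is a hypothetical float `[batch, num_classes]` 0/1 tensor, not part of the code above):

```python
# Scores all classes at once: a [batch, num_classes] matmul,
# which is what gets expensive with 10k-100k classes.
logits = tf.matmul(inputs, nce_weights, transpose_b=True) + nce_biases
dense_loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=multi_hot_labels, logits=logits),
    axis=-1)
```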
I see two issues: 1) making NCE work with different numbers of true values; 2) classes that are NOT mutually exclusive.
To the first issue, as @michal said, there is an expectation that this functionality will be included in the future. I have tried almost the same thing: using labels with `shape=(None, None)`, i.e., a `true_values` dimension of `None`. The `sampled_values` parameter has the same problem: the number of `true_values` must be a fixed integer. The recommended workaround is to use a class (`0` is the best one) representing `<PAD>` and pad the labels up to the `true_values` count. In my case, `0` is a special token that represents `<PAD>`.

I sorted the labels following another recommendation: in my case, the special tokens and more frequent words have lower indexes, while less frequent words have higher indexes. I included all label classes associated with the input at the same time and padded with `0` up to the `true_values` count. Of course, you must ignore the `0` class at the end. Part of the code is below.
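A minimal NumPy sketch of this padding step (the helper name and example values are hypothetical, not the original code):

```python
import numpy as np

def pad_labels(label_lists, num_true):
    # Pad every example's label list with class 0 (<PAD>) up to a fixed
    # num_true, sorting so that lower (more frequent) ids come first.
    padded = np.zeros((len(label_lists), num_true), dtype=np.int64)
    for i, labels in enumerate(label_lists):
        labels = sorted(labels)[:num_true]
        padded[i, :len(labels)] = labels
    return padded

# e.g. pad_labels([[12, 5], [7], [31, 2, 9]], num_true=3)
# -> [[ 5, 12,  0],
#     [ 7,  0,  0],
#     [ 2,  9, 31]]
```

The padded array can then be fed as `labels` to `tf.nn.nce_loss` with a fixed `num_true=3`; class `0` acts only as filler and should be ignored when reading out results.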