Weighted sparse categorical cross entropy


I am dealing with a semantic segmentation problem where the two classes I am interested in (in addition to background) are quite unbalanced in the image pixels. I am currently using sparse categorical cross entropy as the loss, because of the way the training masks are encoded. Is there a version of it that takes class weights into account? I have not been able to find one, nor even the original source code of sparse_categorical_crossentropy. I have never explored the tf source code before, but the link to the source code from the API page doesn't seem to point to a real implementation of the loss function.


There are 3 best solutions below


As far as I know, you can use class weights in model.fit with any loss function. I have used them with categorical_crossentropy and it works. It just scales each sample's loss by the weight of its class, so I see no reason it should not work with sparse_categorical_crossentropy.
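
To make "weights the loss with the class weight" concrete, here is a minimal NumPy sketch of what that scaling amounts to. The function name `weighted_sparse_ce` and the example numbers are made up for illustration, and this is not Keras's actual implementation. One caveat, as far as I know: Keras may reject the `class_weight` argument of `model.fit` for targets with three or more dimensions, such as segmentation masks, in which case the per-pixel weighting has to be supplied another way.

```python
import numpy as np

def weighted_sparse_ce(y_true, probs, class_weights):
    """Per-element sparse cross-entropy, scaled by the weight of the true class.

    y_true: integer class ids, shape (...)
    probs: predicted probabilities, shape (..., num_classes)
    class_weights: one weight per class, shape (num_classes,)
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    probs = np.asarray(probs, dtype=np.float64)
    class_weights = np.asarray(class_weights, dtype=np.float64)
    # Probability the model assigned to the true class of each element.
    p_true = np.take_along_axis(probs, y_true[..., None], axis=-1)[..., 0]
    # Unweighted loss first, then scale each element by its class weight.
    return -np.log(p_true) * class_weights[y_true]

# Hypothetical example: 3 samples, 3 classes, class 2 up-weighted by 4x.
y_true = np.array([0, 1, 2])
probs = np.array([[0.5, 0.25, 0.25],
                  [0.25, 0.5, 0.25],
                  [0.25, 0.25, 0.5]])
losses = weighted_sparse_ce(y_true, probs, class_weights=[1.0, 1.0, 4.0])
print(losses)
```

Every sample here has the same unweighted loss, so the third entry comes out four times larger than the other two purely because of its class weight.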


I think this is the solution for weighting sparse_categorical_crossentropy in Keras. They use the following function to add a "second mask" to the dataset, containing the class weight for each pixel of the mask image.

import tensorflow as tf

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an
  # index into the `class_weights`.
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights


train_dataset.map(add_sample_weights).element_spec

Then they just use tf.keras.losses.SparseCategoricalCrossentropy as the loss function and fit like this:

weighted_model.fit(
    train_dataset.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
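
The indexing step in add_sample_weights can be sketched in plain NumPy, where integer-array indexing plays the role of tf.gather. The 2x3 mask below is a made-up example; the class weights are the ones from the snippet above.

```python
import numpy as np

# Class weights from the snippet above, normalised to sum to 1.
class_weights = np.array([2.0, 2.0, 1.0])
class_weights = class_weights / class_weights.sum()

# Hypothetical 2x3 label mask with a trailing channel axis, shape (2, 3, 1).
label = np.array([[0, 1, 2],
                  [2, 2, 0]])[..., np.newaxis]

# Integer-array indexing looks up one weight per pixel, like tf.gather.
sample_weights = class_weights[label.astype(np.int32)]

print(sample_weights.shape)    # same shape as the label mask
print(sample_weights[..., 0])  # per-pixel weights without the channel axis
```

The result is a weight image aligned pixel for pixel with the mask, which is what gets passed along as the third element of each dataset item.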

It seems that Keras Sparse Categorical Crossentropy doesn't work with class weights. I have found this implementation of a weighted sparse categorical cross-entropy loss for Keras, which works for me. The implementation in the link had a small bug, which may be due to some version incompatibility, so I've fixed it.

    import tensorflow as tf
    from tensorflow import keras

    class WeightedSCCE(keras.losses.Loss):
        """Sparse categorical cross-entropy scaled by a per-class weight."""

        def __init__(self, class_weight, from_logits=False, name='weighted_scce'):
            # Let the base class set up the name and the default reduction.
            super().__init__(name=name)
            self.from_logits = from_logits
            # Skip the weighting entirely when no weights (or all ones) are given.
            if class_weight is None or all(v == 1. for v in class_weight):
                self.class_weight = None
            else:
                self.class_weight = tf.convert_to_tensor(class_weight,
                    dtype=tf.float32)

        def call(self, y_true, y_pred):
            # Per-element loss; the base class __call__ then applies any
            # sample_weight and the reduction.
            loss = keras.losses.sparse_categorical_crossentropy(
                y_true, y_pred, from_logits=self.from_logits)
            if self.class_weight is not None:
                # tf.gather needs integer indices, and the labels may carry a
                # trailing axis of size 1, so match them to the loss shape.
                labels = tf.reshape(tf.cast(y_true, tf.int32), tf.shape(loss))
                loss = loss * tf.gather(self.class_weight, labels)
            return loss

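The loss should be constructed with the list or array of class weights as its argument, one weight per class in class-id order, and the resulting object is then passed as the loss to model.compile.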