I have a question about the mask propagation mechanism in Keras that I don't quite understand. As far as I know, if a custom layer does not implement compute_mask() and has supports_masking = False, then an exception is supposed to be raised when the preceding layer generates a mask. This seems to be reflected in the implementation of compute_mask() in the keras/engine/base_layer.py source code:

# keras/engine/base_layer.py
@generic_utils.default
def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
    """Computes an output mask tensor.
    Args:
        inputs: Tensor or list of tensors.
        mask: Tensor or list of tensors.
    Returns:
        None or a tensor (or list of tensors,
            one per output tensor of the layer).
    """
    print("call base compute mask")  # debug print I added to check if this runs
    if not self._supports_masking:
        if any(m is not None for m in nest.flatten(mask)):
            raise TypeError('Layer ' + self.name + ' does not support masking, '
                            'but was passed an input_mask: ' + str(mask))
        # masking not explicitly supported: return None as mask.
        return None
    # if masking is explicitly supported, by default
    # carry over the input mask
    return mask
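
For reference, calling compute_mask() directly on a layer that doesn't support masking does raise the TypeError from the code above. A minimal sketch (Passthrough is just a dummy layer name I made up):

```python
import tensorflow as tf

class Passthrough(tf.keras.layers.Layer):
    # Minimal custom layer: no compute_mask() override,
    # and supports_masking is left at its default of False.
    def call(self, inputs):
        return inputs

layer = Passthrough()
x = tf.zeros((2, 3))
mask = tf.constant([[True, True, False], [True, False, False]])
print(layer.supports_masking)  # False

try:
    # Invoke the base-class compute_mask() directly with a non-None mask.
    layer.compute_mask(x, mask)
except TypeError as e:
    print("TypeError:", e)
```
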

However, when I slightly modify the example from the tf.keras mask-propagation documentation and experiment with it, no exception is thrown even though (1) compute_mask() is not implemented and (2) supports_masking is False. An exception is raised only when I forcibly call the base class's compute_mask() inside call(). It seems the base class's compute_mask() is never invoked during the forward pass.

My question is: why is the base class's compute_mask() not being called, and why does compute_mask() contain exception-raising code if it is apparently never invoked?

I would appreciate it if someone could clarify this for me.

# MY TEST CODE
import tensorflow as tf

raw_inputs = [
    [711, 632, 71],
    [73, 8, 3215, 55, 927],
    [83, 91, 1, 645, 1253, 927],
]
padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs, padding="post")
unmasked_embedding = tf.cast(tf.tile(tf.expand_dims(padded_inputs, axis=-1), [1, 1, 10]), tf.float32)

class TemporalSplit(tf.keras.layers.Layer):
    def call(self, inputs):
        print('call')
        print(self._supports_masking) # False is printed
        return tf.split(inputs, 2, axis=1)

input = tf.keras.layers.Input(shape=(6, 10), name='input')
masked_input = tf.keras.layers.Masking(mask_value=0)(input)
# Receives a mask from the previous layer, but
# 1. does not implement compute_mask()
# 2. self.supports_masking = True is also not set
output = TemporalSplit()(masked_input)
model = tf.keras.Model(input, output)
print("model forward")
# No exception occurs
model(unmasked_embedding)

I wrote 'MY TEST CODE' expecting an exception, but none was thrown.
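
For completeness, this is roughly the variant where the exception does show up: forcing the base class's compute_mask() from inside call(). (TemporalSplitWithCheck is my own hypothetical name; the mask argument gets filled in by Keras because the call() signature accepts it and Masking sits upstream.)

```python
import tensorflow as tf

class TemporalSplitWithCheck(tf.keras.layers.Layer):
    def call(self, inputs, mask=None):
        # Forcibly run the base-class mask check; with the non-None mask
        # coming from the Masking layer, this raises the TypeError.
        super().compute_mask(inputs, mask)
        return tf.split(inputs, 2, axis=1)

inp = tf.keras.layers.Input(shape=(6, 10))
masked = tf.keras.layers.Masking(mask_value=0.0)(inp)
try:
    out = TemporalSplitWithCheck()(masked)
    print("no exception")
except Exception as e:
    print(type(e).__name__, "raised")
```
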
