I have a question about the mask propagation mechanism in Keras that I don't quite understand. As far as I know, if a custom layer does not implement compute_mask(), has self.supports_masking = False, and the preceding layer generates a mask, an exception should be raised. This seems to be reflected in the implementation of compute_mask() in keras/engine/base_layer.py:
```python
# keras/engine/base_layer.py
@generic_utils.default
def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
  """Computes an output mask tensor.

  Args:
    inputs: Tensor or list of tensors.
    mask: Tensor or list of tensors.

  Returns:
    None or a tensor (or list of tensors,
      one per output tensor of the layer).
  """
  print("call base compute mask")  # debug print I added; not in the original source
  if not self._supports_masking:
    if any(m is not None for m in nest.flatten(mask)):
      raise TypeError('Layer ' + self.name + ' does not support masking, '
                      'but was passed an input_mask: ' + str(mask))
    # masking not explicitly supported: return None as mask.
    return None
  # if masking is explicitly supported, by default
  # carry over the input mask
  return mask
```
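The @generic_utils.default decorator appears to exist so the base class can tell whether a subclass overrode compute_mask(). Below is a minimal pure-Python sketch of that idea, assuming masking support is inferred from whether compute_mask() was overridden. The names default, is_default, BaseLayer, NoMaskLayer and MaskLayer are hypothetical and this is not the Keras source.

```python
# Simplified, hypothetical model of the bookkeeping behind @generic_utils.default.
# NOT the Keras source; all names here are illustrative.


def default(method):
    """Mark a method as the base-class default implementation."""
    method._is_default = True
    return method


def is_default(method):
    """True if `method` is still the marked default (i.e. not overridden)."""
    return getattr(method, "_is_default", False)


class BaseLayer:
    def __init__(self, name):
        self.name = name
        # Assumption: masking support is inferred from whether the subclass
        # overrode compute_mask().
        self._supports_masking = not is_default(self.compute_mask)

    @default
    def compute_mask(self, inputs, mask=None):
        if not self._supports_masking:
            if mask is not None:
                raise TypeError(
                    f"Layer {self.name} does not support masking, "
                    f"but was passed an input_mask: {mask}")
            return None  # masking not supported: drop the mask
        return mask  # masking supported: carry the input mask over


class NoMaskLayer(BaseLayer):
    pass  # keeps the default compute_mask()


class MaskLayer(BaseLayer):
    def compute_mask(self, inputs, mask=None):
        return mask  # overriding compute_mask() flips _supports_masking to True
```

In this model, calling NoMaskLayer("x").compute_mask(inputs, mask=some_mask) raises the TypeError, while MaskLayer simply passes the mask through.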
However, when I slightly modify the mask propagation example from the tf.keras masking documentation and run it under the conditions (1) compute_mask() is not implemented and (2) self.supports_masking is False, no exception is raised. An exception is raised only when I explicitly call the base class's compute_mask() inside call(). It seems that the base class's compute_mask() is never called otherwise.
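This observation can be modelled without TensorFlow. The sketch below uses a hypothetical SketchLayer (the class and its compute_mask_calls counter are made up for illustration, not Keras API) whose __call__ skips mask handling entirely when the layer does not support masking. Under that assumption, the TypeError branch in compute_mask() is reachable only by calling compute_mask() directly, which matches the behaviour described above. Whether a given TF version short-circuits in exactly this way is an assumption to check against its own base_layer.py.

```python
# Hypothetical model of a layer __call__ that only touches masks when the
# layer declares masking support. Not the Keras source.


class SketchLayer:
    def __init__(self, name, supports_masking=False):
        self.name = name
        self.supports_masking = supports_masking
        self.compute_mask_calls = 0  # counts how often compute_mask() runs

    def compute_mask(self, inputs, mask=None):
        self.compute_mask_calls += 1
        if not self.supports_masking and mask is not None:
            raise TypeError(f"Layer {self.name} does not support masking")
        return mask if self.supports_masking else None

    def call(self, inputs):
        return inputs  # identity layer, enough for the illustration

    def __call__(self, inputs, mask=None):
        outputs = self.call(inputs)
        # Assumed short-circuit: without masking support, compute_mask() is
        # never invoked, so the incoming mask is silently dropped.
        if not self.supports_masking:
            return outputs, None
        return outputs, self.compute_mask(inputs, mask)


layer = SketchLayer("temporal_split")
out, out_mask = layer([1, 2, 3], mask=[True, True, False])  # no exception
```

Here out_mask is None and compute_mask_calls stays 0; only a direct layer.compute_mask(..., mask=...) call raises.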
My questions are: why is the base class's compute_mask() not being called, and why does compute_mask() contain code that raises an exception if that code never seems to run?
I would appreciate it if someone could clarify this for me.
```python
# MY TEST CODE
import tensorflow as tf

raw_inputs = [
    [711, 632, 71],
    [73, 8, 3215, 55, 927],
    [83, 91, 1, 645, 1253, 927],
]
padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs, padding="post")
# Shape (3, 6, 10): each token id repeated across 10 features; padding stays 0.
unmasked_embedding = tf.cast(
    tf.tile(tf.expand_dims(padded_inputs, axis=-1), [1, 1, 10]), tf.float32)


class TemporalSplit(tf.keras.layers.Layer):
    def call(self, inputs):
        print('call')
        print(self._supports_masking)  # False is printed
        return tf.split(inputs, 2, axis=1)


inputs = tf.keras.layers.Input(shape=(6, 10), name='input')
masked_input = tf.keras.layers.Masking(mask_value=0)(inputs)
# Receives a mask from the previous layer, but
# 1. does not implement compute_mask()
# 2. does not set self.supports_masking = True
output = TemporalSplit()(masked_input)
model = tf.keras.Model(inputs, output)

print("model forward")
# No exception occurs
model(unmasked_embedding)
```
I wrote this test code expecting an exception to be raised, but none was.
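To check the premise that a non-None mask really reaches TemporalSplit, here is a TensorFlow-free sketch of the same pipeline: post-padding to the longest sequence, tiling each token to 10 features, and the boolean mask that Masking(mask_value=0) would produce (a timestep is masked only when all of its features equal mask_value). The helper names pad_post, tile_features and masking_mask are hypothetical stand-ins, not Keras functions.

```python
# Pure-Python sketch of the data fed to TemporalSplit in the test above.

raw_inputs = [
    [711, 632, 71],
    [73, 8, 3215, 55, 927],
    [83, 91, 1, 645, 1253, 927],
]


def pad_post(seqs, value=0):
    """Pad every sequence at the end to the length of the longest one."""
    maxlen = max(len(s) for s in seqs)
    return [s + [value] * (maxlen - len(s)) for s in seqs]


def tile_features(padded, depth=10):
    """Repeat each token id `depth` times as floats: shape (batch, time, depth)."""
    return [[[float(tok)] * depth for tok in seq] for seq in padded]


def masking_mask(batch, mask_value=0.0):
    """Keep a timestep if any of its features differs from mask_value."""
    return [[any(x != mask_value for x in step) for step in seq] for seq in batch]


padded = pad_post(raw_inputs)
embedded = tile_features(padded)
mask = masking_mask(embedded)
for row in mask:
    print(row)
# [True, True, True, False, False, False]
# [True, True, True, True, True, False]
# [True, True, True, True, True, True]
```

So the layer after Masking does receive a mask with False entries; the question is only why it is silently dropped instead of triggering the TypeError.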