Keras Custom Batch Normalization layer with an extra variable that can be changed at run time


I have implemented a custom version of Batch Normalization by adding a self.skip attribute that should act somewhat like a trainable switch. Here is the minimal code:

from tensorflow.keras.layers import BatchNormalization
import tensorflow as tf

# class CustomBN(tf.keras.layers.Layer):

class CustomBN(BatchNormalization):

    def __init__(self, **kwargs):
        super(CustomBN, self).__init__(**kwargs)
        self.skip = False

    def call(self, inputs, training=None):
        if self.skip:
            tf.print("I'm skipping")
        else:
            tf.print("I'm not skipping")
        return super(CustomBN, self).call(inputs, training)

    def build(self, input_shape):
        super(CustomBN, self).build(input_shape)

To be crystal clear, all I have done so far is:

  1. subclass BatchNormalization: should I subclass tf.keras.layers.Layer instead?
  2. define self.skip to change the behavior of the CustomBN layer at run time.
  3. check the state of self.skip in the call method and act accordingly.

Now, to change the behavior of the CustomBN layer, I use

self.model.layers[ind].skip = state

where state is either True or False, and ind is the index of the CustomBN layer in the model. The evident problem is that the value of self.skip never changes. If you notice any mistakes, please let me know.
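For what it's worth, the symptom reproduces outside Keras with a plain tf.function (a minimal sketch, not my actual layer; Toggle and f are just illustration names):

```python
import tensorflow as tf

class Toggle:
    def __init__(self):
        self.skip = False

toggle = Toggle()

@tf.function
def f(x):
    # This Python `if` runs once, while the graph is traced,
    # so the chosen branch is frozen into the graph.
    if toggle.skip:
        return x
    return x + 1.0

print(f(tf.constant(1.0)).numpy())  # 2.0 (traced with skip=False)
toggle.skip = True
print(f(tf.constant(1.0)).numpy())  # still 2.0 -- flipping the attribute has no effect
```

Since the input signature is unchanged, the second call reuses the already-traced graph, so the attribute flip is never seen.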

1 Answer

By default, the call function of your layer is invoked when the graph is built, not on a per-batch basis. The Keras model compile method has a run_eagerly option that causes your model to run (more slowly) in eager mode, invoking your call function directly without building a graph. However, this is most likely not what you want.

Ideally, you want the flag that changes the behavior to be an input to the call method. For instance, you can add an extra input to your graph that is simply this state flag and pass it to your layer.

The following example shows how to build a graph that branches conditionally on an extra input.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class MyLayerWithFlag(keras.layers.Layer):
    def call(self, inputs, flag=None):
        c_one = tf.constant([1], dtype=tf.float32)

        if flag is not None:
            # tf.cond builds both branches into the graph and selects
            # between them at run time, based on the flag tensor.
            return tf.cond(flag,
                           lambda: tf.math.add(inputs, c_one),
                           lambda: inputs)
        return inputs


inputs = layers.Input(shape=(2,))
# The state flag enters the graph as a second model input.
state = layers.Input(shape=(1,), dtype=tf.bool)
x = MyLayerWithFlag()(inputs, flag=state)
out = layers.Lambda(tf.reduce_sum)(x)
model = keras.Model([inputs, state], out)

data = np.array([[1., 2.]])
state = np.array([[True]])
model.predict((data, state))
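A further option, not shown above and offered only as a sketch: keep the flag in a non-trainable tf.Variable. The graph reads the variable's current value each time it runs, so .assign() changes the behavior between calls without retracing and without an extra model input (SwitchAdd is an illustrative layer, not the asker's CustomBN):

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class SwitchAdd(keras.layers.Layer):
    """Adds 1 to its input unless the `skip` variable is True."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Non-trainable variable: read at run time, toggled with .assign().
        self.skip = tf.Variable(False, trainable=False)

    def call(self, inputs):
        return tf.cond(self.skip, lambda: inputs, lambda: inputs + 1.0)

inp = layers.Input(shape=(2,))
switch = SwitchAdd()
model = keras.Model(inp, switch(inp))

x = np.array([[1., 2.]], dtype=np.float32)
print(model.predict(x, verbose=0))  # [[2. 3.]]
switch.skip.assign(True)            # flip the flag -- no rebuild needed
print(model.predict(x, verbose=0))  # [[1. 2.]]
```

The trade-off versus the extra-input approach is that the flag is global to the layer rather than supplied per batch.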