I have implemented a custom version of Batch Normalization, adding a self.skip variable that is meant to act as a run-time switch. Here is the minimal code:
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization

# class CustomBN(tf.keras.layers.Layer):
class CustomBN(BatchNormalization):
    def __init__(self, **kwargs):
        super(CustomBN, self).__init__(**kwargs)
        self.skip = False

    def call(self, inputs, training=None):
        if self.skip:
            tf.print("I'm skipping")
        else:
            tf.print("I'm not skipping")
        return super(CustomBN, self).call(inputs, training)

    def build(self, input_shape):
        super(CustomBN, self).build(input_shape)
To be crystal clear, all I have done so far is:
- subclassing BatchNormalization (should I subclass tf.keras.layers.Layer instead?)
- defining self.skip to change the behavior of the CustomBN layer at run time
- checking the state of self.skip in the call method to act accordingly
Now, to change the behavior of the CustomBN layer, I use

self.model.layers[ind].skip = state

where state is either True or False, and ind is the index of the CustomBN layer in the model.
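For reference, here is a minimal sketch of how I drive it (the Dense model and the random data are just placeholders):

import numpy as np

inputs = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Dense(8)(inputs)
outputs = CustomBN()(x)   # CustomBN is model.layers[2] here, so ind = 2
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

x_train = np.random.rand(32, 4).astype("float32")
y_train = np.random.rand(32, 8).astype("float32")

model.fit(x_train, y_train, epochs=1, verbose=0)  # prints "I'm not skipping"
model.layers[2].skip = True                       # flip the flag at run time
model.fit(x_train, y_train, epochs=1, verbose=0)  # still prints "I'm not skipping"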
The evident problem is that the new value of self.skip never takes effect. If you notice any mistakes, please notify me.
By default, the call method of your layer is invoked when the graph is built, not on a per-batch basis, so the Python-level if self.skip: is evaluated only once, at trace time. The Keras Model compile method has a run_eagerly option that causes your model to run (slower) in eager mode, which would invoke your call method on every batch without building a graph.
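For example (run_eagerly is a standard compile argument; the optimizer and loss here are placeholders):

# Forces eager execution, so the Python `if self.skip:` in call
# is re-evaluated on every batch instead of once at trace time.
model.compile(optimizer="adam", loss="mse", run_eagerly=True)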
This is most likely not what you want to do, however. Ideally, you want the flag that changes the behavior to be an input to the call method. For instance, you can add an extra input to your graph that is simply this state flag and pass it to your layer. The following is an example of how you can have a conditional graph on an extra parameter.
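A minimal sketch of that idea (the ToggleBN layer, model wiring, and data below are my own illustration, assuming TF 2.x):

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization

class ToggleBN(BatchNormalization):
    # The skip flag is a tensor input, so tf.cond chooses the branch
    # inside the graph at run time, not once at trace time.
    def build(self, input_shape):
        # inputs arrive as [x, skip]; BatchNormalization only needs x's shape
        super(ToggleBN, self).build(input_shape[0])

    def call(self, inputs, training=None):
        x, skip = inputs
        return tf.cond(
            tf.reduce_any(skip),
            lambda: tf.identity(x),  # skip branch: pass inputs through
            lambda: super(ToggleBN, self).call(x, training=training))

data_in = tf.keras.Input(shape=(8,))
skip_in = tf.keras.Input(shape=(), dtype=tf.bool)  # one flag per example
out = ToggleBN()([data_in, skip_in])
model = tf.keras.Model([data_in, skip_in], out)

x = np.random.rand(4, 8).astype("float32")
normalized = model.predict([x, np.zeros(4, dtype=bool)])      # runs batch norm
passed_through = model.predict([x, np.ones(4, dtype=bool)])   # skips it

Because the flag travels through the graph as data, flipping it between calls changes the behavior without retracing.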