Can't fold BatchNorm with Conv2D in Keras QAT basic example


I'm currently trying to use Keras' Quantization Aware Training, specifically because I need to run 8-bit inference on a low-precision device. For this reason, I need to fold the batch norm into the convolution, so that the 32-bit moving mean and variance don't survive into the deployed graph. The sample code I'm starting from is the following (tf 1.15, tensorflow-model-optimization 0.6.0):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(224, 224, 3)),
        tf.keras.layers.Conv2D(filters=3, kernel_size=(3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1000)
    ])



    quantize_model = tfmot.quantization.keras.quantize_model

    # q_aware stands for quantization aware.
    q_aware_model = quantize_model(model)

    # `quantize_model` requires a recompile.
    smooth = 0.1  # label-smoothing factor; any value in [0, 1) works here
    q_aware_model.compile(optimizer='adam',
                          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=smooth),
                          metrics=['accuracy'])

    q_aware_model.summary()
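To be explicit about what I mean by folding: the BatchNorm statistics get baked into the convolution's kernel and bias, so no separate 32-bit moving mean/variance tensors remain at inference time. A minimal sketch of the arithmetic (my own NumPy illustration, not tfmot code; `eps` matches Keras' BatchNormalization default):

    import numpy as np

    def fold_bn_into_conv(kernel, bias, gamma, beta, mean, var, eps=1e-3):
        # Per-output-channel scale derived from the BN statistics.
        scale = gamma / np.sqrt(var + eps)
        # Kernel has shape (kh, kw, in_ch, out_ch); scale broadcasts over out_ch.
        folded_kernel = kernel * scale
        folded_bias = beta + (bias - mean) * scale
        return folded_kernel, folded_bias

This is what I expect the QAT transform to do internally for the Conv2D+BN pair.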

The documentation states that the 'Conv2D+BN+ReLU' pattern should have its BatchNorm folded, but that isn't the case in the .h5 file produced: the BatchNormalization layer is still present as a separate layer.
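As a quick check (a simple diagnostic sketch, just printing each layer's class), listing the layers of the quantized model is how I confirmed that the BatchNormalization was not fused into the Conv2D:

    for layer in q_aware_model.layers:
        print(layer.__class__.__name__, layer.name)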
