How to remove batch normalization layers from a Keras model?


I would like to remove all the batch normalization layers from a Keras model that includes short skip connections. For example, let's consider EfficientNetB0 as follows:

import tensorflow as tf

model = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)

I was actually using the 3D version of EfficientNet, which I don't think matters for the question, but I'll show it anyway:


import keras 
from keras import layers
from keras.models import Model


input_shape = (32,32,32,3)

import efficientnet_3D.keras as efn 
model = efn.EfficientNetB0(input_shape=input_shape, weights='imagenet')

X = model.layers[-1].output
X = layers.Flatten()(X)
X = layers.Dense(16)(X)
X = layers.Dense(16)(X)
X = layers.Dense(1)(X)

model = Model(inputs=model.inputs, outputs=X)
model.compile(loss='mse',
              optimizer='adam',
              metrics=['mean_absolute_error']
              )
model.summary()

I tried to develop my own way of removing the layers, but it seems to be totally wrong: the output model ends up messy wherever the shortcut connections are.



import keras
from keras import layers
from keras.models import Model

# Indices of the batch normalization layers (their names contain 'bn')
ind = [i for i, l in enumerate(model.layers) if 'bn' in l.name]


X = model.layers[0].output
for i in range(1, len(model.layers)):
    
    # Skipping Batch Normalization layers
    if i in ind:
        # model.layers[i]._inbound_nodes = []
        # model.layers[i]._outbound_nodes = []
        continue
        
    # If there is a short skip 
    if isinstance(model.layers[i].input, list):
        input_names = [j.name for j in model.layers[i].input]
        assert len(input_names) == 2
        input_names.remove(X.name)
        input_names = input_names[0].split('/')[0] 
        # X = [model.get_layer(input_names).output, X]
        X = [model.layers[6].output, X]
        
    if isinstance(X, list):
        print(i)
    X = model.layers[i](X)

new_model = Model(inputs=model.inputs, outputs=X)

I think there should be a better way that I'm not aware of. I tried the methods from a similar question about removing a single layer, but because my model includes skip connections, they don't work here. Any help is appreciated.
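For reference, one common pattern for this kind of graph surgery is to rebuild the model tensor by tensor: walk `model.layers` in order, keep a map from each original symbolic tensor to its replacement, and for a batch normalization layer simply map its output to its (remapped) input instead of calling it. This handles skip connections automatically, because `Add`/`Concatenate` layers just look up all of their inputs in the same map. A minimal sketch on a toy residual model (the toy model and the `strip_bn` helper are my own illustration, not code from the question; it assumes `model.layers` is in topological order, which holds for functional models):

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

# A toy model with a BatchNorm inside a residual branch,
# standing in for the EfficientNet blocks above.
inp = layers.Input(shape=(8,))
x = layers.Dense(8, name="dense_a")(inp)
x = layers.BatchNormalization(name="bn_a")(x)
x = layers.Dense(8, name="dense_b")(x)
out = layers.Add(name="add")([inp, x])
model = Model(inp, out)

def strip_bn(model):
    """Rebuild `model` with every BatchNormalization layer bypassed."""
    # Map each original symbolic tensor (keyed by id) to its
    # replacement in the new graph.
    tensor_map = {id(t): t for t in model.inputs}
    for layer in model.layers:
        if isinstance(layer, layers.InputLayer):
            continue
        # Grab the original input/output tensors *before* re-calling
        # the layer (afterwards `layer.output` becomes ambiguous).
        orig_in = layer.input if isinstance(layer.input, list) else [layer.input]
        orig_out = layer.output
        mapped = [tensor_map[id(t)] for t in orig_in]
        if isinstance(layer, layers.BatchNormalization):
            # Bypass: consumers of the BN output now read its input.
            tensor_map[id(orig_out)] = mapped[0]
            continue
        # Re-call the original layer (reusing its weights) on the
        # remapped inputs; multi-input layers get the full list.
        tensor_map[id(orig_out)] = layer(mapped if len(mapped) > 1 else mapped[0])
    return Model(model.inputs, [tensor_map[id(t)] for t in model.outputs])

clean = strip_bn(model)
```

The same idea should apply to EfficientNetB0, since the `Add` layers in its skip connections are handled by the tensor map rather than by hardcoded layer indices.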
