I am training a transfer-learning model (DenseNet121) with two image inputs and want to plot Grad-CAM to visualize the activation maps. My data are 3D images (shape: 128 x 128 x 36). Because the pre-trained models are 2D, I split each volume into stacks of 3 slices to train the model and later concatenated them, so the prediction input has the shape (2, 12, 224, 224, 3).
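For clarity, the 3-slice stacking described above can be sketched in NumPy. This is a minimal sketch under assumptions: the helper name `volume_to_stacks` is hypothetical, consecutive slices are grouped (stack i = slices 3i, 3i+1, 3i+2), and the 128 -> 224 resize uses nearest-neighbour indexing because the actual interpolation is not stated.

```python
import numpy as np


def volume_to_stacks(volume, out_size=224, slices_per_stack=3):
    """Turn one (H, W, D) volume into D // slices_per_stack pseudo-RGB images.

    Hypothetical helper. Returns shape (D // slices_per_stack, out_size, out_size, slices_per_stack).
    Resizing is nearest-neighbour (an assumption; any interpolation would do).
    """
    h, w, d = volume.shape
    n_stacks = d // slices_per_stack
    # Map every output row/column back to the nearest input row/column.
    rows = np.arange(out_size) * h // out_size
    cols = np.arange(out_size) * w // out_size
    resized = volume[rows][:, cols]  # (out_size, out_size, D)
    # Drop leftover slices, then group consecutive slices into stacks.
    resized = resized[:, :, : n_stacks * slices_per_stack]
    stacks = resized.reshape(out_size, out_size, n_stacks, slices_per_stack)
    return np.transpose(stacks, (2, 0, 1, 3))  # (n_stacks, out_size, out_size, 3)


# Example: two 128 x 128 x 36 volumes -> one two-input sample
vol_a = np.random.rand(128, 128, 36).astype("float32")
vol_b = np.random.rand(128, 128, 36).astype("float32")
sample = np.stack([volume_to_stacks(vol_a), volume_to_stacks(vol_b)])
print(sample.shape)  # (2, 12, 224, 224, 3)
```

If stack i of one volume is paired with stack i of the other (again an assumption about how the inputs line up), `model.predict([sample[0], sample[1]])` would return 12 predictions, one per stack pair.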

My code is as follows:

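```python
from tensorflow.keras import layers
from tensorflow.keras.layers import Dropout, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model, load_model

# model_dir, image_size (224) and data_augmentation are defined elsewhere.
base_model = load_model(model_dir)  # pre-trained DenseNet121 is in model_dir
base_model.trainable = False  # freeze the backbone

# Branch 1
input1 = Input(shape=(image_size, image_size, 3), name='inp1')
x = data_augmentation(input1)
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x1 = Dropout(0.5)(x)

# Branch 2 reuses the same base_model (shared weights)
input2 = Input(shape=(image_size, image_size, 3), name='inp2')
x = data_augmentation(input2)
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x2 = Dropout(0.5)(x)

# Merge both branches into a single sigmoid output
concatenated = layers.concatenate([x1, x2], axis=-1)
output = layers.Dense(1, activation='sigmoid')(concatenated)
model = Model([input1, input2], output)
```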

I tried to plot Grad-CAM by following this example in Keras: https://keras.io/examples/vision/grad_cam/
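For reference, the heatmap step of that example reduces to: average the gradients over spatial positions to get one weight per channel, take the weighted sum of the activation maps, apply ReLU and normalize. A minimal NumPy sketch of just that step for a single image, assuming the activations and gradients have already been computed (the name `gradcam_heatmap` is hypothetical). With two inputs it would be applied once per branch.

```python
import numpy as np


def gradcam_heatmap(conv_output, grads):
    """Grad-CAM for one image from precomputed activations and gradients.

    conv_output: (H, W, K) activations of the chosen conv layer.
    grads: (H, W, K) gradient of the class score w.r.t. those activations.
    Returns an (H, W) heatmap scaled to [0, 1].
    """
    weights = grads.mean(axis=(0, 1))  # (K,) one importance weight per channel
    heatmap = conv_output @ weights  # weighted sum over channels -> (H, W)
    heatmap = np.maximum(heatmap, 0)  # keep only positive evidence
    peak = heatmap.max()
    return heatmap / peak if peak > 0 else heatmap


# Toy example: channel 0 fires at one spot, channel 1 is uniform.
acts = np.zeros((7, 7, 2))
acts[3, 3, 0] = 5.0
acts[:, :, 1] = 1.0
grads = np.stack([np.full((7, 7), 1.0), np.full((7, 7), -0.5)], axis=-1)
h = gradcam_heatmap(acts, grads)  # 1.0 at (3, 3), 0 elsewhere
```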

When I apply it to the multi-input network, I get this error:

```
ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'") at layer "zero_padding2d_2". The following previous layers were accessed without issue: []
```

I believe the error comes from this part of the code:

```python
grad_model = keras.models.Model(
    model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
)
```

I also tried changing it to the following, but I still get the same error:

```python
grad_model = tf.keras.models.Model(
    [model.get_layer('inp1').input, model.get_layer('inp2').input],
    [model.get_layer('densenet121').get_layer('conv5_block16_2_conv').output, model.output],
)
```

How can I adapt this for a multi-input transfer-learning model?
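The traceback names `input_2` and `zero_padding2d_2`, which look like layers of the nested `densenet121` model itself. A likely explanation (an assumption, not confirmed here) is that `.output` of a layer inside a nested model belongs to that nested model's own graph, which is not connected to `inp1`/`inp2`; and since one backbone is shared by both inputs, there are two activation maps, not one. The sketch below is one possible workaround under these assumptions: it wraps the backbone in a sub-model exposing the conv maps, rebuilds the two branches around it, and reuses the trained Dense head (assumed to be `model.layers[-1]`). Augmentation and Dropout are omitted because they act as identity at inference. Both function names are hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def build_two_input_grad_model(base_model, head_dense, conv_layer_name, image_size=224):
    """Hypothetical helper: two-branch model that also outputs both conv maps."""
    # Sub-model wired to the backbone's OWN input, exposing conv maps and features.
    inner = keras.Model(
        base_model.inputs,
        [base_model.get_layer(conv_layer_name).output, base_model.outputs[0]],
    )
    inp1 = keras.Input(shape=(image_size, image_size, 3), name="inp1")
    inp2 = keras.Input(shape=(image_size, image_size, 3), name="inp2")
    conv1, feat1 = inner(inp1)  # same weights for both branches
    conv2, feat2 = inner(inp2)
    x1 = keras.layers.GlobalAveragePooling2D()(feat1)
    x2 = keras.layers.GlobalAveragePooling2D()(feat2)
    merged = keras.layers.concatenate([x1, x2], axis=-1)
    output = head_dense(merged)  # reuse the trained head weights
    return keras.Model([inp1, inp2], [conv1, conv2, output])


def two_input_gradcam(grad_model, img1, img2):
    """Hypothetical helper: one Grad-CAM heatmap per input (first image in batch)."""
    img1 = tf.convert_to_tensor(img1, dtype=tf.float32)
    img2 = tf.convert_to_tensor(img2, dtype=tf.float32)
    with tf.GradientTape() as tape:
        # Watch the inputs so ops inside a frozen backbone are recorded.
        tape.watch([img1, img2])
        conv1, conv2, preds = grad_model([img1, img2], training=False)
        score = preds[:, 0]
    grads1, grads2 = tape.gradient(score, [conv1, conv2])
    heatmaps = []
    for conv, grads in ((conv1, grads1), (conv2, grads2)):
        weights = tf.reduce_mean(grads[0], axis=(0, 1))  # per-channel weights
        cam = tf.reduce_sum(conv[0] * weights, axis=-1)  # weighted channel sum
        cam = tf.maximum(cam, 0)
        cam = cam / (tf.reduce_max(cam) + 1e-8)  # scale to [0, 1]
        heatmaps.append(cam.numpy())
    return heatmaps
```

Usage would then be along the lines of `grad_model = build_two_input_grad_model(base_model, model.layers[-1], 'conv5_block16_2_conv')` followed by `heat1, heat2 = two_input_gradcam(grad_model, img1[None], img2[None])`, giving one heatmap per input.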
