I have implemented a one-class autoencoder that includes a Batchwise Feature Weighting (BFW) block, described as an attention mechanism that helps identify salient features useful for classification. However, the model stagnates when trained with the BFW layer; if I remove the layer, it trains very well. What could be the issue?
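For context, the computation the BFW block is supposed to perform, restated from the layer docstring below in LaTeX form, is:

w = \mathrm{softmax}\!\left( \frac{1}{B} \sum_{i=1}^{B} \left( W_2\,\sigma(W_1 z_i + b_1) + b_2 \right) \right), \qquad z_i = \sigma(W_0 x_i + b_0)

where x_i is the i-th input vector, B is the batch size, σ is the sigmoid function, and W_1, b_1, W_2, b_2 are the learnable parameters.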
import tensorflow as tf  # needed for tf.cast in BFW.call
from keras.optimizers import Adam
from keras.layers import Input, Dense, LeakyReLU, Reshape, Flatten, Multiply
from keras.models import Model
import keras.backend as K
from keras.layers import Layer
import matplotlib.pyplot as plt
class BFW(Layer):
    """
    Custom layer implementing w = softmax(1/B * ∑(i=1 to B) (W2 · σ(W1 · zi + b1) + b2)),
    where zi = σ(W0 · xi + b0), xi is the ith input vector, B is the batch size,
    σ is the sigmoid function, W1, W2, b1, b2 are the learnable parameters,
    and w is the weight vector.
    """
    def __init__(self, L, **kwargs):
        self.L = L             # dimensionality of the incoming latent features
        self.L_prime = L // 2  # L' is half of L
        super(BFW, self).__init__(**kwargs)
    def build(self, input_shape):
        # W1: (L, L'), b1: (L',) -- first projection of the latent vector
        self.W1 = self.add_weight(shape=(input_shape[1], self.L_prime),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.b1 = self.add_weight(shape=(self.L_prime,),
                                  initializer='zeros',
                                  trainable=True)
        # W2: (L', 1), b2: (1,) -- second projection down to a single score
        self.W2 = self.add_weight(shape=(self.L_prime, 1),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.b2 = self.add_weight(shape=(1,),
                                  initializer='zeros',
                                  trainable=True)
        super(BFW, self).build(input_shape)
    def compute_weighted_sum(self, x_i):
        # Score a single sample: x_i has shape (L,)
        x_i = K.expand_dims(x_i, axis=0)                # -> (1, L)
        z_i = K.dot(x_i, self.W1)                       # -> (1, L')
        z_i = K.bias_add(z_i, self.b1)
        z_i = K.relu(z_i)
        weighted_sum_i = K.dot(z_i, self.W2) + self.b2  # -> (1, 1)
        return weighted_sum_i

    def call(self, x):
        B = K.shape(x)[0]  # batch size
        # Map the per-sample scoring function over the batch
        batch_weights = K.map_fn(self.compute_weighted_sum, x)
        scaled_batch_weights = batch_weights / tf.cast(B, tf.float32)
        w = K.softmax(scaled_batch_weights)
        return w
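For reference, here is a minimal, standalone way I can sanity-check the layer's output shape in isolation (dummy_batch and bfw_check are just illustrative names, and the random array stands in for my real encoded data):

import numpy as np

# Illustrative only: random data standing in for a batch of 8 encoded samples with L = 256.
dummy_batch = tf.constant(np.random.rand(8, 256).astype("float32"))
bfw_check = BFW(L=256)
print(bfw_check(dummy_batch).shape)  # should print (8, 1, 1), consistent with the (None, 1, 1) row in the model summary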
# Define hyperparameters (SIZE is defined elsewhere; the model summary below shows SIZE = 32)
input_shape = (SIZE, SIZE, 1)
# Encoder
encoder_input = Input(shape=input_shape)
x = Flatten()(encoder_input)
x = Dense(512)(x)
x = LeakyReLU(alpha=0.02)(x)
x = Dense(512)(x)
x = LeakyReLU(alpha=0.02)(x)
# Define the encoder output with the desired latent dimension
encoder_output = Dense(256)(x)
# Apply encoder output to the BFW layer
bfw_layer = BFW(L=256)(encoder_output)
# Decoder
x = Dense(512)(bfw_layer) # Use the weighted latent representation
x = LeakyReLU(alpha=0.02)(x)
x = Dense(512)(x)
x = LeakyReLU(alpha=0.02)(x)
x = Dense(SIZE * SIZE * 1, activation='sigmoid')(x)
decoded = Reshape((SIZE, SIZE, 1))(x)
# Define the autoencoder model
autoencoder = Model(encoder_input, decoded)
# Compile the model
optimizer = Adam(learning_rate=0.001)
autoencoder.compile(optimizer=optimizer, loss='mean_squared_error')
# Print the model summary
autoencoder.summary()
# Fit the model (train_generator, validation_generator, and batch_size are defined elsewhere in my script).
history = autoencoder.fit(
train_generator,
steps_per_epoch=6817 // batch_size,
epochs=35,
validation_data=validation_generator,
validation_steps=6282 // batch_size,
shuffle=True)
# Plot the training and validation loss at each epoch
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'y', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
Here is the model summary and how it trains:
Model: "model_13"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_16 (InputLayer) [(None, 32, 32, 1)] 0
flatten_15 (Flatten) (None, 1024) 0
dense_84 (Dense) (None, 512) 524800
leaky_re_lu_56 (LeakyReLU) (None, 512) 0
dense_85 (Dense) (None, 512) 262656
leaky_re_lu_57 (LeakyReLU) (None, 512) 0
dense_86 (Dense) (None, 256) 131328
bfw_15 (BFW) (None, 1, 1) 33025
dense_87 (Dense) (None, 1, 512) 1024
leaky_re_lu_58 (LeakyReLU) (None, 1, 512) 0
dense_88 (Dense) (None, 1, 512) 262656
leaky_re_lu_59 (LeakyReLU) (None, 1, 512) 0
dense_89 (Dense) (None, 1, 1024) 525312
reshape_13 (Reshape) (None, 32, 32, 1) 0
=================================================================
Total params: 1,740,801
Trainable params: 1,740,801
Non-trainable params: 0
_________________________________________________________________
Epoch 1/35
106/106 [==============================] - 11s 94ms/step - loss: 0.0354 - val_loss: 0.0301
Epoch 2/35
106/106 [==============================] - 9s 86ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 3/35
106/106 [==============================] - 9s 84ms/step - loss: 0.0291 - val_loss: 0.0302
Epoch 4/35
106/106 [==============================] - 9s 81ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 5/35
106/106 [==============================] - 8s 79ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 6/35
106/106 [==============================] - 8s 78ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 7/35
106/106 [==============================] - 8s 76ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 8/35
106/106 [==============================] - 8s 75ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 9/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 10/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 11/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 12/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0302
Epoch 13/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 14/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 15/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 16/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 17/35
106/106 [==============================] - 8s 74ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 18/35
106/106 [==============================] - 8s 78ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 19/35
106/106 [==============================] - 8s 75ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 20/35
106/106 [==============================] - 8s 78ms/step - loss: 0.0292 - val_loss: 0.0301
Epoch 21/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 22/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0299
Epoch 23/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 24/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 25/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 26/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 27/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 28/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 29/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 30/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 31/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 32/35
106/106 [==============================] - 8s 74ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 33/35
106/106 [==============================] - 8s 73ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 34/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 35/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
FYI - this is a one-class autoencoder that is trained only on instances of the normal class.
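In case it is relevant: after training, the usual way such a one-class model is used is to score samples by reconstruction error, roughly along these lines (x_test and the percentile threshold below are placeholders for illustration, not my actual pipeline):

import numpy as np

# Rough sketch of one-class scoring: reconstruct the inputs and flag samples
# whose reconstruction error is unusually high. x_test and the threshold are placeholders.
reconstructions = autoencoder.predict(x_test)
errors = np.mean(np.square(x_test - reconstructions), axis=(1, 2, 3))  # per-sample MSE
threshold = np.percentile(errors, 95)  # placeholder threshold
is_anomalous = errors > threshold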