One-Class Autoencoder

I have implemented a one-class autoencoder with a Batchwise Feature Weighting (BFW) block, which is described as an attention mechanism that helps identify salient features useful for classification. However, with the BFW layer the model does not learn: the loss stops improving after the first epoch (see the log below). If I remove the layer, it trains very well. What could be the issue?


from keras.optimizers import Adam
from keras.layers import Input, Dense, LeakyReLU, Reshape, Flatten, Multiply
from keras.models import Model
import keras.backend as K
from keras.layers import Layer
import matplotlib.pyplot as plt
import tensorflow as tf  # needed for tf.cast / tf.float32 in BFW.call

class BFW(Layer):
    """
    Custom layer implementing
        w = softmax(1/B * ∑(i=1 to B) (W2 · σ(W1 · z_i + b1) + b2))
    where
        z_i = σ(W0 · x_i + b0), x_i being the ith input vector,
        B is the batch size,
        σ is the sigmoid function,
        W1, W2, b1, b2 are the learnable parameters,
        w is the weight vector.
    """

    def __init__(self, L, **kwargs):
        self.L = L
        self.L_prime = L // 2  # L' is half of L
        super(BFW, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is (batch, L); L = 256 for the encoder output used below
        # W1, b1: hidden projection L -> L'
        self.W1 = self.add_weight(shape=(input_shape[1], self.L_prime),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.b1 = self.add_weight(shape=(self.L_prime,),
                                  initializer='zeros',
                                  trainable=True)
        # W2, b2: output projection L' -> 1 (a single scalar per sample)
        self.W2 = self.add_weight(shape=(self.L_prime, 1),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.b2 = self.add_weight(shape=(1,),
                                  initializer='zeros',
                                  trainable=True)
        super(BFW, self).build(input_shape)

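    def compute_weighted_sum(self, x_i):
        # x_i: one latent vector of shape (L,); add a leading axis -> (1, L)
        x_i = K.expand_dims(x_i, axis=0)

        # hidden layer: (1, L) @ (L, L') + b1 -> (1, L')
        # note: ReLU is used here, whereas the docstring formula uses σ (sigmoid)
        z_i = K.dot(x_i, self.W1)
        z_i = K.bias_add(z_i, self.b1)
        z_i = K.relu(z_i)

        # output layer: (1, L') @ (L', 1) + b2 -> (1, 1), one score per sample
        weighted_sum_i = K.dot(z_i, self.W2) + self.b2

        return weighted_sum_i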
    
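    def call(self, x):
        B = K.shape(x)[0]  # batch size (dynamic)

        # apply compute_weighted_sum to each sample -> shape (B, 1, 1)
        batch_weights = K.map_fn(self.compute_weighted_sum, x)

        # divide each sample's score by B (no sum over the batch is taken here)
        scaled_batch_weights = batch_weights / tf.cast(B, tf.float32)

        # softmax along the last axis (K.softmax default axis=-1), which has size 1
        w = K.softmax(scaled_batch_weights)

        return w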

# Define hyperparameters
SIZE = 32  # image height/width; 32x32x1 inputs per the model summary below
input_shape = (SIZE, SIZE, 1)

# Encoder
encoder_input = Input(shape=input_shape)

x = Flatten()(encoder_input)
x = Dense(512)(x)
x = LeakyReLU(alpha=0.02)(x)

x = Dense(512)(x)
x = LeakyReLU(alpha=0.02)(x)

# Define the encoder output with the desired latent dimension
encoder_output = Dense(256)(x)

# Pass the encoder output through the BFW layer
bfw_layer = BFW(L=256)(encoder_output)

# Decoder
x = Dense(512)(bfw_layer)  # intended as the weighted latent representation; the input here is the BFW output w
x = LeakyReLU(alpha=0.02)(x)

x = Dense(512)(x)
x = LeakyReLU(alpha=0.02)(x)

x = Dense(SIZE * SIZE * 1, activation='sigmoid')(x)
decoded = Reshape((SIZE, SIZE, 1))(x)

# Define the autoencoder model
autoencoder = Model(encoder_input, decoded)

# Compile the model
optimizer = Adam(learning_rate=0.001)
autoencoder.compile(optimizer=optimizer, loss='mean_squared_error')

# Print the model summary
autoencoder.summary()

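# Fit the model.
# train_generator, validation_generator and batch_size are defined elsewhere (not shown);
# the 106 steps per epoch in the log below are consistent with batch_size = 64 (6817 // 64 = 106).
history = autoencoder.fit(
    train_generator,
    steps_per_epoch=6817 // batch_size,
    epochs=35,
    validation_data=validation_generator,
    validation_steps=6282 // batch_size,
    shuffle=True)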

# Plot the training and validation loss at each epoch
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'y', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
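For comparison, here is a minimal NumPy sketch of the formula in the BFW docstring, computed for a whole batch at once. It assumes that W2 maps L' to L (shape (L', L)), so that w has one weight per latent feature; the layer above instead uses W2 of shape (L', 1), which yields a single score per sample. It follows the docstring's sigmoid hidden activation and uses row vectors (z @ W1), as K.dot does. The final element-wise multiplication of z by w is also an assumption (hinted at by the unused Multiply import), not something the formula states. All weights and data here are random placeholders.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def softmax(a):
    # numerically stable softmax over a 1-D vector
    e = np.exp(a - a.max())
    return e / e.sum()

def bfw_weights(z, W1, b1, W2, b2):
    """w = softmax(1/B * sum_i (W2 . sigmoid(W1 . z_i + b1) + b2)).

    z: (B, L) batch of latent vectors z_i
    W1: (L, L'), b1: (L',), W2: (L', L), b2: (L,)  -- W2 shape is an assumption
    Returns w of shape (L,): non-negative per-feature weights summing to 1.
    """
    h = sigmoid(z @ W1 + b1)           # (B, L')
    scores = h @ W2 + b2               # (B, L)
    mean_scores = scores.mean(axis=0)  # 1/B * sum over the batch -> (L,)
    return softmax(mean_scores)

rng = np.random.default_rng(0)
B, L = 8, 256
L_prime = L // 2
z = rng.standard_normal((B, L))
W1 = rng.standard_normal((L, L_prime)) * 0.05
b1 = np.zeros(L_prime)
W2 = rng.standard_normal((L_prime, L)) * 0.05
b2 = np.zeros(L)

w = bfw_weights(z, W1, b1, W2, b2)
z_weighted = z * w  # broadcast: every sample's features scaled by the same w
print(w.shape, z_weighted.shape)  # (256,) (8, 256)
```

Because w has one entry per latent feature in this sketch, multiplying it into z keeps the latent size at L, so a following Dense layer would see a (batch, 256) input rather than (batch, 1, 1).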

Here is the model summary and training log with the BFW layer:

Model: "model_13"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_16 (InputLayer)       [(None, 32, 32, 1)]       0         
                                                                 
 flatten_15 (Flatten)        (None, 1024)              0         
                                                                 
 dense_84 (Dense)            (None, 512)               524800    
                                                                 
 leaky_re_lu_56 (LeakyReLU)  (None, 512)               0         
                                                                 
 dense_85 (Dense)            (None, 512)               262656    
                                                                 
 leaky_re_lu_57 (LeakyReLU)  (None, 512)               0         
                                                                 
 dense_86 (Dense)            (None, 256)               131328    
                                                                 
 bfw_15 (BFW)                (None, 1, 1)              33025     
                                                                 
 dense_87 (Dense)            (None, 1, 512)            1024      
                                                                 
 leaky_re_lu_58 (LeakyReLU)  (None, 1, 512)            0         
                                                                 
 dense_88 (Dense)            (None, 1, 512)            262656    
                                                                 
 leaky_re_lu_59 (LeakyReLU)  (None, 1, 512)            0         
                                                                 
 dense_89 (Dense)            (None, 1, 1024)           525312    
                                                                 
 reshape_13 (Reshape)        (None, 32, 32, 1)         0         
                                                                 
=================================================================
Total params: 1,740,801
Trainable params: 1,740,801
Non-trainable params: 0
_________________________________________________________________
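The summary shows that bfw_15 outputs shape (None, 1, 1), and its 33,025 parameters match W1 (256×128), b1 (128), W2 (128×1) and b2 (1) from build(). The quick NumPy check below, a sketch independent of Keras, shows what a softmax over a last axis of length 1 evaluates to: every entry is exactly 1, whatever the score. If that is what K.softmax receives here, the decoder would get the same value for every image.

```python
import numpy as np

def softmax(a, axis=-1):
    # numerically stable softmax along `axis` (K.softmax also defaults to axis=-1)
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Scores shaped like the BFW output in the summary: (batch, 1, 1)
scores = np.array([[[-3.2]], [[0.0]], [[5.7]], [[42.0]]])
w = softmax(scores)
print(w.ravel())  # [1. 1. 1. 1.]

# Parameter count of the BFW layer from the shapes in build(), with L = 256
L, L_prime = 256, 256 // 2
n_params = L * L_prime + L_prime + L_prime * 1 + 1
print(n_params)  # 33025
```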
Epoch 1/35
106/106 [==============================] - 11s 94ms/step - loss: 0.0354 - val_loss: 0.0301
Epoch 2/35
106/106 [==============================] - 9s 86ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 3/35
106/106 [==============================] - 9s 84ms/step - loss: 0.0291 - val_loss: 0.0302
Epoch 4/35
106/106 [==============================] - 9s 81ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 5/35
106/106 [==============================] - 8s 79ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 6/35
106/106 [==============================] - 8s 78ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 7/35
106/106 [==============================] - 8s 76ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 8/35
106/106 [==============================] - 8s 75ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 9/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 10/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 11/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 12/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0302
Epoch 13/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 14/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 15/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 16/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 17/35
106/106 [==============================] - 8s 74ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 18/35
106/106 [==============================] - 8s 78ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 19/35
106/106 [==============================] - 8s 75ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 20/35
106/106 [==============================] - 8s 78ms/step - loss: 0.0292 - val_loss: 0.0301
Epoch 21/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 22/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0299
Epoch 23/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 24/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 25/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 26/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 27/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 28/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 29/35
106/106 [==============================] - 8s 71ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 30/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
Epoch 31/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 32/35
106/106 [==============================] - 8s 74ms/step - loss: 0.0290 - val_loss: 0.0301
Epoch 33/35
106/106 [==============================] - 8s 73ms/step - loss: 0.0290 - val_loss: 0.0300
Epoch 34/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0300
Epoch 35/35
106/106 [==============================] - 8s 72ms/step - loss: 0.0291 - val_loss: 0.0301
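One possible reading of the flat curve, offered as a hypothesis rather than a diagnosis: if the decoder receives the same input for every image, it can only output one fixed reconstruction, and under mean squared error the best fixed reconstruction is the per-pixel mean of the training images. The loss would then plateau near the average per-pixel variance of the data. The sketch below checks that identity on random stand-in data; X is hypothetical and is not the asker's dataset.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for flattened 32x32 training images with values in [0, 1]
X = rng.uniform(0.0, 1.0, size=(500, 32 * 32))

mean_image = X.mean(axis=0)
mse_mean = np.mean((X - mean_image) ** 2)  # MSE when predicting the mean image for every sample
avg_pixel_var = X.var(axis=0).mean()       # average per-pixel variance (ddof=0)

# Any other constant prediction does worse
other = mean_image + 0.05
mse_other = np.mean((X - other) ** 2)

print(np.isclose(mse_mean, avg_pixel_var), mse_other > mse_mean)  # True True
```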

FYI: this is a one-class autoencoder that is trained only on instances of the normal class.
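A one-class autoencoder trained only on normal data is commonly used by thresholding the reconstruction error: inputs the model reconstructs poorly are flagged as anomalous. A minimal NumPy sketch of that scoring step follows. The synthetic arrays are hypothetical stand-ins for real images and for the output of autoencoder.predict, and the 95th-percentile threshold is an assumed choice, not something taken from the question.

```python
import numpy as np

def reconstruction_errors(x, x_hat):
    # per-sample mean squared error over all pixels; arrays shaped (N, H, W, C)
    return np.mean((x - x_hat) ** 2, axis=(1, 2, 3))

rng = np.random.default_rng(0)
# Hypothetical stand-ins: in practice x_hat would come from autoencoder.predict(x)
x_normal = rng.uniform(0, 1, size=(100, 32, 32, 1))
x_hat_normal = x_normal + rng.normal(0, 0.05, size=x_normal.shape)
x_test = rng.uniform(0, 1, size=(10, 32, 32, 1))
x_hat_test = x_test + rng.normal(0, 0.3, size=x_test.shape)

# Threshold at the 95th percentile of errors on normal data (an assumption)
threshold = np.percentile(reconstruction_errors(x_normal, x_hat_normal), 95)
is_anomaly = reconstruction_errors(x_test, x_hat_test) > threshold
print(is_anomaly)
```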
