EDIT: embarrassingly, my error was that I was shuffling the data but not the labels.
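For anyone hitting the same symptom: the images and labels have to stay aligned, so shuffle them with the same permutation (or let a DataLoader shuffle (image, label) pairs together). A minimal sketch, assuming x is the full image tensor and y the matching label tensor (the names are just illustrative):

import torch
from torch.utils.data import TensorDataset, DataLoader

# shuffle images and labels with the SAME permutation so they stay aligned
perm = torch.randperm(x.size(0))
x, y = x[perm], y[perm]

# or let a DataLoader shuffle the (image, label) pairs together
loader = DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)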
I was given an assignment to build an LSTM autoencoder in PyTorch that reconstructs MNIST images. The assignment then asked me to modify the network so it also classifies the reconstructed images. An important requirement is that it should do the two tasks at the same time, reconstruction and classification of the reconstructed image, so the network should train on both losses simultaneously.
My implementation of the autoencoder looks like this:
import torch.nn as nn

class AE(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, epochs,
                 optimizer, learning_rate, grad_clip, batch_size):
        super(AE, self).__init__()
        self.encoder = Encoder(input_size, hidden_size, num_layers)
        self.decoder = Decoder(input_size, hidden_size, num_layers, output_size)
        self.epochs = epochs
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.grad_clip = grad_clip
        self.batch_size = batch_size
        self.criterion = nn.MSELoss()  # reconstruction loss
        self.losses = []
The forward and train methods work fine, and when I run the network on the MNIST data set I get fairly well reconstructed images, with the MSE loss averaging around 1e-6.
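The base forward isn't shown here; it's just encode-then-decode, roughly like this (a sketch only, since the exact interface of my LSTM Encoder/Decoder modules isn't shown above):

    def forward(self, x):
        # encode the input image (treated as a sequence of rows) into a latent state,
        # then decode it back into a reconstruction with the same shape as x
        latent = self.encoder(x)
        predictions = self.decoder(latent)
        return predictions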
I introduced the classification element in a separate class:
class AeWithClassifier(AE):
    def __init__(self, input_size, hidden_size, num_layers, output_size, epochs,
                 optimizer, learning_rate, grad_clip, batch_size, num_classes):
        super(AeWithClassifier, self).__init__(input_size, hidden_size, num_layers,
                                               output_size, epochs, optimizer,
                                               learning_rate, grad_clip, batch_size)
        # single linear layer mapping the flattened reconstruction to class logits
        self.classifier = nn.Sequential(
            nn.Linear(output_size*output_size, num_classes))
        self.classifier_criterion = nn.CrossEntropyLoss()
The methods are pretty straightforward, but I will provide them:
    def forward(self, x):
        predictions = super().forward(x)  # reconstructed images from the base AE
        classifier_predictions = self.classifier(predictions.reshape(-1, 28*28))
        return predictions, classifier_predictions
    def train(self, x, y):
        losses = []
        optimizer = self.optimizer(self.parameters(), lr=self.learning_rate)
        for epoch in range(self.epochs):
            cur_loss = 0
            batch_idx = 0
            for batch_idx, x_batch in enumerate(x):
                x_batch = x_batch.to(device)
                # labels for this batch, taken in the same order as the image batches
                y_batch = y[batch_idx*self.batch_size:(batch_idx+1)*self.batch_size]
                optimizer.zero_grad()
                predictions, classifier_predictions = self.forward(x_batch)
                recon_loss = self.criterion(predictions, x_batch)
                class_loss = self.classifier_criterion(classifier_predictions, y_batch)
                # joint objective: reconstruction + classification
                cur_loss = loss = recon_loss + class_loss
                loss.backward()
                nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
                optimizer.step()
            losses.append(cur_loss.item())
            print(f'Epoch: {epoch+1}/{self.epochs}, Loss: {cur_loss.item()}')
        self.losses = losses
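For completeness, this is roughly how train gets called (a sketch with made-up hyperparameter values; x_batches is the training set already split into batches of self.batch_size, and y_labels is the full label tensor in the same order, which is exactly where the shuffling mistake from the EDIT bites):

model = AeWithClassifier(input_size=28, hidden_size=64, num_layers=1, output_size=28,
                         epochs=10, optimizer=torch.optim.Adam, learning_rate=1e-3,
                         grad_clip=1.0, batch_size=64, num_classes=10).to(device)
model.train(x_batches, y_labels)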
As you can see, I compute the loss as the sum of the reconstruction loss and the classification loss, and then let torch handle the gradient computation and the optimization step.
However, in this setup, although the network still reconstructs the images, it fails to classify them properly: the cross-entropy loss doesn't decrease below ~2.3, which is about log(10), i.e. random guessing over 10 classes.
Am I doing something wrong in the construction of the network, or is the problem in the training itself? I tried weighting the losses differently so the network would focus more on the classification task, but it still doesn't improve at all.
Since your classification layer is a single linear layer that takes the reconstruction predictions as input and outputs the class predictions, the network has to strike a difficult balance between reconstruction quality and linear separability of the reconstructed images.
If you want to keep this architecture, you could try giving the reconstruction loss and the classification loss different weights, something like:

loss = recon_loss + 10*class_loss

You can also try adding an activation and another linear layer on top of the current classification layer for more capacity (keep the output as raw logits, since nn.CrossEntropyLoss applies log-softmax internally). However, it's probably better to change the architecture and produce the classification and the reconstruction predictions with different network branches from the latent representation, similarly to this: http://tech.octopus.energy/timeserio/_images/MNIST.svg.
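A minimal sketch of that branched setup (assumptions: the class name AeWithClassifierBranched is mine, and self.encoder is taken to return one latent vector of size hidden_size per image; adapt it to whatever your Encoder actually returns). The classifier reads the latent code instead of the reconstructed pixels, so the two heads only share the encoder:

class AeWithClassifierBranched(AE):
    def __init__(self, input_size, hidden_size, num_layers, output_size, epochs,
                 optimizer, learning_rate, grad_clip, batch_size, num_classes):
        super().__init__(input_size, hidden_size, num_layers, output_size, epochs,
                         optimizer, learning_rate, grad_clip, batch_size)
        # classification head on the latent code, not on the reconstruction
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes))
        self.classifier_criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        latent = self.encoder(x)            # shared encoder
        predictions = self.decoder(latent)  # reconstruction branch
        logits = self.classifier(latent.reshape(latent.size(0), -1))  # classification branch
        return predictions, logits

With this split, the reconstruction no longer has to be linearly separable by digit, so the reconstruction loss and the classification loss stop fighting over the same tensor.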