I have a code that is using Variational autoencoder to generated new images that I try to plot them but in vain. The code is coded using pytorch as below :
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5), (0.5))
])
transform = transforms.ToTensor()
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
batch_size=64,
shuffle=True)
dataiter = iter(data_loader)
images, labels = dataiter.next()
print(torch.min(images), torch.max(images))
# repeatedly reduce the size
class Autoencoder_Linear(nn.Module):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 128), # (N, 784) -> (N, 128)
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 12),
nn.ReLU(),
nn.Linear(12, 3) # -> N, 3
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 28 * 28),
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
# Input [-1, +1] -> use nn.Tanh
class Autoencoder(nn.Module):
def __init__(self):
super().__init__()
# N, 1, 28, 28
self.encoder = nn.Sequential(
nn.Conv2d(1, 16, 3, stride=2, padding=1), # -> N, 16, 14, 14
nn.ReLU(),
nn.Conv2d(16, 32, 3, stride=2, padding=1), # -> N, 32, 7, 7
nn.ReLU(),
nn.Conv2d(32, 64, 7) # -> N, 64, 1, 1
)
# N , 64, 1, 1
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, 7), # -> N, 32, 7, 7
nn.ReLU(),
nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), # N, 16, 14, 14 (N,16,13,13 without output_padding)
nn.ReLU(),
nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1), # N, 1, 28, 28 (N,1,27,27)
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
# Note: nn.MaxPool2d -> use nn.MaxUnpool2d, or use different kernelsize, stride etc to compensate...
# Input [-1, +1] -> use nn.Tanh
model = Autoencoder()
#criterion = nn.MSELoss() #From MSE to KLDISTANCE
#criterion = loss_fn = nn.KLDivLoss(reduction='batchmean')
loss_fn = nn.KLDivLoss(reduction="batchmean")
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-3,
weight_decay=1e-5)
# Point to training loop video
num_epochs = 1
outputs = []
for epoch in range(num_epochs):
for (img, _) in data_loader:
# img = img.reshape(-1, 28*28) # -> use for Autoencoder_Linear
recon = model(img)
# loss = criterion(recon, img)
loss = loss_fn(F.log_softmax(recon, dim=1), F.softmax(img, dim=1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
outputs.append((epoch, img, recon))
for k in range(0, num_epochs, 4):
plt.figure(figsize=(9, 2))
plt.gray()
imgs = outputs[k][1].detach().numpy()
recon = outputs[k][2].detach().numpy()
for i, item in enumerate(imgs):
if i >= 9: break
plt.subplot(2, 9, i+1)
# item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
for i, item in enumerate(recon):
if i >= 9: break
plt.subplot(2, 9, 9+i+1) # row_length + i + 1
# item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
#plt.show() # Indentation fixed here
print(f'img shape: {img.shape}')
print(f'img type: {img.dtype}')
print(f'recon shape: {recon.shape}')
print(f'recon shape: {recon.dtype}')
I tried to use RMSE as loss function, the code is showing a small loss function error and it shows the constructed image with no problem, but when using the KL distance loss function, it doesn't show the constructed images, it shows only the real images. I mean by the real images the ones that are chosen as dataset.