I'm new to GANs and have been trying to train a WGAN on 3D micro-CT images with one channel, of shape (H, W, D). However, both the discriminator (critic) and generator losses come out negative. Can you please point out the reason for that? I have provided the discriminator and generator architectures and the training loop below.
[![Discriminator and generator architectures](https://i.stack.imgur.com/Gd5xw.png)](https://i.stack.imgur.com/iqzOG.png)
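For reference, the losses the loop below is meant to implement are the standard WGAN objectives for the critic $D$ and the generator $G$:

$$
L_D = -\big(\mathbb{E}_{x}[D(x)] - \mathbb{E}_{z}[D(G(z))]\big),
\qquad
L_G = -\,\mathbb{E}_{z}[D(G(z))]
$$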
import torch
import torchvision.utils as vutils

CRITIC_ITERATIONS = 5
WEIGHT_CLIP = 0.01
# Number of training epochs
num_epochs = 5
dataloader = train_loader
# Real/fake labels are not used by the WGAN losses below
real_label = 1.0
fake_label = 0.0
# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D (critic): minimize -[mean(D(real)) - mean(D(fake))]
        ############################
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)  # number of images in this batch
        for _ in range(CRITIC_ITERATIONS):
            netD.zero_grad()
            noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
            netD_real_output = netD(real_cpu).view(-1)
            fake = netG(noise)
            # Pass the fake batch through the critic; detach() so this
            # update does not backpropagate into the generator
            netD_fake_output = netD(fake.detach()).view(-1)
            errD = -(torch.mean(netD_real_output) - torch.mean(netD_fake_output))
            errD.backward()
            optimizerD.step()
            # Weight clipping to enforce the Lipschitz constraint (vanilla WGAN)
            for p in netD.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)
        ############################
        # (2) Update G network: minimize -mean(D(G(z)))
        ############################
        netG.zero_grad()
        noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
        fake = netG(noise)
        # Note: `fake` is not detached here because the generator's weights
        # should be updated based on the critic's output
        errG_fake_output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = -torch.mean(errG_fake_output)
        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizerG.step()
        # Output training stats
        if i % 50 == 0:
            print(f"Epoch [ {epoch}/{num_epochs}] Batch {i}/{len(dataloader)} "
                  f"loss D: {errD:.4f}, loss G: {errG:.4f}")
        # Save losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(fake)
        iters += 1
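After training, the saved losses can be plotted with a minimal sketch like the one below (assuming `matplotlib` is installed; the figure size and labels are arbitrary):

```python
import matplotlib.pyplot as plt

# Plot the per-iteration critic and generator losses collected above
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("loss")
plt.legend()
plt.show()
```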
Output
Starting Training Loop...
Epoch [ 0/5] Batch 0/11 loss D: -0.8412, loss G: -1.1208
Epoch [ 1/5] Batch 0/11 loss D: -1.5598, loss G: -1.7733
Epoch [ 2/5] Batch 0/11 loss D: -2.0055, loss G: -2.1531