Issue with WGAN model - Negative losses for the discriminator and generator


I'm new to GANs and have been trying to train a WGAN on 3D micro-CT images with a single channel, i.e. volumes of shape (H, W, D). However, both the discriminator (critic) loss and the generator loss come out negative. Can you please point out the reason for that? I have provided the discriminator and generator architectures and the training loop below.
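For context, these are the two objectives the loop below is meant to implement. The WGAN critic has no sigmoid, so its scores (and therefore both losses) are unbounded real numbers. A minimal toy computation with made-up scores (not from my model) to show the signs:

import torch

# made-up critic scores, purely for illustration
d_real = torch.tensor([1.1, 0.9, 1.0])  # critic scores on real volumes
d_fake = torch.tensor([0.2, 0.4, 0.3])  # critic scores on generated volumes

critic_loss = -(d_real.mean() - d_fake.mean())  # same form as errD below
gen_loss = -d_fake.mean()                       # same form as errG below
print(critic_loss.item(), gen_loss.item())      # ≈ -0.7 and -0.3: both can legitimately be negative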

(Architecture screenshots: https://i.stack.imgur.com/Gd5xw.png and https://i.stack.imgur.com/iqzOG.png)

import torch
import torchvision.utils as vutils

CRITIC_ITERATIONS = 5
WEIGHT_CLIP = 0.01

# Number of training epochs
num_epochs = 5
dataloader = train_loader

# Labels are a leftover from the standard (non-Wasserstein) GAN convention;
# the WGAN critic is trained on raw scores, so these are unused below
real_label = 1.0
fake_label = 0.0



# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D (the critic): minimize -(mean D(real) - mean D(fake))
        ###########################
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0) # number of images in this batch
        for _ in range(CRITIC_ITERATIONS):
            netD.zero_grad()
            noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
            
            netD_real_output = netD(real_cpu).view(-1)
            fake = netG(noise)
            # pass the fakes through the critic; detach so the generator
            # receives no gradients during the critic update
            netD_fake_output = netD(fake.detach()).view(-1)
            errD = -(torch.mean(netD_real_output) - torch.mean(netD_fake_output))
            errD.backward()
            optimizerD.step()
            
            # clip critic weights after each update to (roughly) enforce the
            # 1-Lipschitz constraint the WGAN objective requires
            for p in netD.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)
                                        
        ############################
        # (2) Update G network: minimize -mean D(G(z)), i.e. maximize the critic score on fakes
        ###########################
        netG.zero_grad()
        noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
        fake = netG(noise)
        errG_fake_output = netD(fake).view(-1) # no detach here: gradients must flow back through the critic into the generator
        # Calculate G's loss based on this output
        errG = - torch.mean(errG_fake_output)
        # Calculate gradients for G
        errG.backward()

        # Update G
        optimizerG.step()
        
        # Output training stats
        if i % 50 == 0:
            print(f"Epoch [ {epoch}/{num_epochs}] Batch {i}/{len(dataloader)} \
                    loss D: {errD:.4f}, loss G: {errG:.4f}")

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(fake)

        iters += 1

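As a side note, vutils is imported but never used above. Since the generated samples are 5D tensors of shape (B, 1, H, W, D), they have to be reduced to 2D images before torchvision's make_grid can tile them. A minimal sketch, where the mid-depth slice is just an arbitrary choice for visualisation:

mid = img_list[-1].shape[-1] // 2                 # index of the middle depth slice
slices = img_list[-1][..., mid]                   # (B, 1, H, W) 2D slices
grid = vutils.make_grid(slices, normalize=True)   # tile the slices into one image tensor
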
Output

Starting Training Loop...
Epoch [ 0/5] Batch 0/11                     loss D: -0.8412, loss G: -1.1208
Epoch [ 1/5] Batch 0/11                     loss D: -1.5598, loss G: -1.7733
Epoch [ 2/5] Batch 0/11                     loss D: -2.0055, loss G: -2.1531