StyleGAN2 training: NaN for gp and loss_critic, resulting in black pictures

I'm following the code from a guide step by step. When I train the model on my own pictures, which are 128x128, both gp and loss_critic become NaN during training, so I'm guessing there is a problem either with my data or with the equations in the code, and I'd like more insight into what is going wrong. Below are the lines that compute gp and loss_critic:

gp = gradient_penalty(critic, real, fake, device=DEVICE)
loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp + (0.001 * torch.mean(critic_real ** 2))
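For context, gradient_penalty is the helper from the guide; as far as I can tell it is the standard WGAN-GP penalty, roughly like this (my paraphrase, not a verbatim copy):

def gradient_penalty(critic, real, fake, device="cpu"):
    # Interpolate between real and fake images with random per-sample weights
    batch_size, c, h, w = real.shape
    beta = torch.rand((batch_size, 1, 1, 1)).repeat(1, c, h, w).to(device)
    interpolated = real * beta + fake.detach() * (1 - beta)
    interpolated.requires_grad_(True)

    # Critic scores on the interpolated images
    mixed_scores = critic(interpolated)

    # Gradient of the scores with respect to the interpolated images
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalize deviation of the gradient norm from 1
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)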

and here is my training loop:

def train_fn(
    critic,
    gen,
    path_length_penalty,
    loader,
    opt_critic,
    opt_gen,
    opt_mapping_network,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        w     = get_w(cur_batch_size)
        noise = get_noise(cur_batch_size)
        with torch.cuda.amp.autocast():
            # Generator forward pass, then critic scores on fake and real
            fake = gen(w, noise)
            critic_fake = critic(fake.detach())

            critic_real = critic(real)
            # WGAN-GP gradient penalty on interpolations of real and fake
            gp = gradient_penalty(critic, real, fake, device=DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        # Generator update: maximize the critic's score on the fakes
        gen_fake = critic(fake)
        loss_gen = -torch.mean(gen_fake)

        # Path length regularization every 16 batches (skipped when it is NaN)
        if batch_idx % 16 == 0:
            plp = path_length_penalty(w, fake)
            if not torch.isnan(plp):
                loss_gen = loss_gen + plp

        mapping_network.zero_grad()
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        opt_mapping_network.step()

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

and here is the guide that I am following: https://blog.paperspace.com/implementation-stylegan2-from-scratch/#load-all-dependencies-we-need
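To narrow it down, I was thinking of dropping a quick NaN check into the loop (just a debugging sketch, not part of the guide), right after the autocast block in train_fn:

# Debugging sketch: report which tensor turns NaN first
for name, t in [("fake", fake), ("critic_real", critic_real),
                ("critic_fake", critic_fake), ("gp", gp)]:
    if torch.isnan(t).any():
        print(f"NaN first appears in {name} at batch {batch_idx}")
        break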

I suspect it might be a CUDA or GPU problem. Any insights or solutions?
