"derivative for aten::linear_backward is not implemented" when calling backward() on mps in torch

I'm working on a GAN that generates sounds. Most of the code is copied from the wavegan-pytorch GitHub repository. Since I'm working on a MacBook with an M2 chip, I wanted to shift the processing from the CPU to the GPU via MPS. But when I call torch.Tensor.backward() on my loss, I get an error saying that the derivative for linear_backward is not implemented. I'm still pretty new to programming: is there a simple mistake I'm overlooking, or is it just not possible to run this code on the GPU? Here's my code:

real_signal = next(self.train_loader)

# need to add mixed signal and flag
noise = sample_noise(batch_size * generator_batch_size_factor)
generated = self.generator(noise)
#############################
# Calculating discriminator loss and updating discriminator
#############################
self.apply_zero_grad()
disc_cost, disc_wd = self.calculate_discriminator_loss(
    real_signal.data, generated.data
)
assert not (torch.isnan(disc_cost))
disc_cost.backward()
self.optimizer_d.step()

I would be very glad for any help. Let me know if you need more info. Sorry in advance if there's a simple solution that I'm just not seeing; I'm new to this.
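
For context, the device selection looks roughly like this (a simplified sketch; the exact attribute names in my trainer class may differ):

import torch

# use the Apple GPU backend (MPS) if it is available, otherwise fall back to the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# the generator, the discriminator and all input tensors are then moved to this device,
# e.g. self.generator.to(device) and self.discriminator.to(device)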

Here is the code for the calculate_discriminator_loss() function:

def calculate_discriminator_loss(self, real, generated):
    disc_out_gen = self.discriminator(generated)
    disc_out_real = self.discriminator(real)

    alpha = torch.FloatTensor(batch_size * 2, 1, 1).uniform_(0, 1).to(device)
    alpha = alpha.expand(batch_size * 2, real.size(1), real.size(2))

    interpolated = (1 - alpha) * real.data + (alpha) * generated.data[:batch_size * 2]
    interpolated = Variable(interpolated, requires_grad=True)

    # calculate probability of interpolated examples
    prob_interpolated = self.discriminator(interpolated)
    grad_inputs = interpolated
    ones = torch.ones(prob_interpolated.size()).to(device)
    gradients = grad(
        outputs=prob_interpolated,
        inputs=grad_inputs,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # calculate gradient penalty
    grad_penalty = (
        p_coeff
        * ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
    )
    assert not (torch.isnan(grad_penalty))
    assert not (torch.isnan(disc_out_gen.mean()))
    assert not (torch.isnan(disc_out_real.mean()))
    cost_wd = disc_out_gen.mean() - disc_out_real.mean()
    cost = cost_wd + grad_penalty
    return cost, cost_wd
1 Answer

Answer by Robin van Hoorn:

Seeing that you are implementing the discriminator loss calculation from a WGAN-GP, I thought I'd work out what was going wrong and improve what you have.

First, you were doing absolutely great, with some slight flaws here and there. The problem is indeed in the calculate_discriminator_loss function. Things to improve:

  1. Variable is deprecated in recent versions of PyTorch. I'd recommend not using it, since it is no longer supported; see the short sketch after this list for the modern replacement.
  2. You can index the generated and real tensors without accessing the data attribute, like so: `generated[:batch_size * 2]`.
  3. I am not sure what you are trying to do with batch_size * 2. Is the generated batch bigger than the real batch of data? I would advise keeping them the same size.
  4. PyTorch has an expand_as function which is really useful here (instead of expand, where you have to spell out the target size yourself).
  5. When computing the gradients, you do not need retain_graph=True; it already defaults to the value of create_graph, which you set to True.
  6. You also do not need only_inputs=True. The argument is deprecated, and it defaults to True anyway.
  7. p_coeff and device are not variables defined in the function. Make sure to define them in the class, and then access them through self.p_coeff and self.device.
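
As a short sketch of point 1, the Variable line can be replaced with plain tensor calls (this assumes real and generated do not already require gradients, which is the case in your snippet because you pass in .data):

interpolated = (1 - alpha) * real + alpha * generated
# make the interpolated batch a leaf tensor that autograd tracks,
# so torch.autograd.grad can differentiate the critic output w.r.t. it
interpolated = interpolated.detach().requires_grad_(True)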

The following works when I run it:

def calculate_discriminator_loss(self, real, generated):
    assert real.shape == generated.shape
    disc_out_gen = self.discriminator(generated)
    disc_out_real = self.discriminator(real)

    # one alpha per sample; the trailing singleton dims are expanded over channels and time below
    alpha = torch.rand(self.batch_size, 1, 1).to(self.device)
    alpha = alpha.expand_as(real)

    interpolated = (1 - alpha) * real + alpha * generated

    # calculate probability of interpolated examples
    prob_interpolated = self.discriminator(interpolated)
    ones = torch.ones(prob_interpolated.size()).to(self.device)
    gradients = grad(
        outputs=prob_interpolated,
        inputs=interpolated,
        grad_outputs=ones,
        create_graph=True)[0]

    # calculate gradient penalty, scaled by the penalty coefficient (see point 7)
    grad_penalty = (
        self.p_coeff
        * torch.mean((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2)
    )

    cost_wd = disc_out_gen.mean() - disc_out_real.mean()
    cost = cost_wd + grad_penalty
    return cost, cost_wd

Cleaned up your code a bit as well to be more readable and removed the asserts.
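
On the training-loop side, the call would then look roughly like this (a sketch based on your loop, passing the tensors directly instead of their .data attributes; only optimizer_d is stepped, so the generator's weights are not updated here even though gradients flow through it):

self.apply_zero_grad()
disc_cost, disc_wd = self.calculate_discriminator_loss(real_signal, generated)
disc_cost.backward()
self.optimizer_d.step()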

Hope this helps.