PyTorch: compute gradients with respect to weights without autograd


I know the question sounds like it answers itself (compute weight gradients without gradients?). The issue I am trying to resolve is that I need to change leaf variables before computing the backward pass, and without `no_grad` PyTorch complains about in-place modifications.
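For concreteness, this is the failure I mean (a minimal standalone sketch with a single Conv2d, not the actual model below):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 10, kernel_size=5)

# In-place update of a leaf parameter that requires grad raises a RuntimeError
try:
    conv.weight *= 3
except RuntimeError as e:
    print(e)  # "a leaf Variable that requires grad is being used in an in-place operation."

# Inside no_grad the in-place update is allowed, but nothing is recorded for backward
with torch.no_grad():
    conv.weight *= 3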

  1. Define the model (it is split into two stages of a pipeline):
import torch
import torch.nn as nn
import torch.nn.functional as F

class Pipe1(nn.Module):
    def __init__(self):
        super(Pipe1, self).__init__()
        # first pipeline stage: convolutional feature extractor
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)

        # (these two layers are never used in this stage's forward)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        return x


class Pipe2(nn.Module):
    def __init__(self):
        super(Pipe2, self).__init__()
        # second pipeline stage: classifier head
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
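For reference, a quick shape check (assuming MNIST-sized 28x28 inputs, which is what the 320 in fc1 suggests): 28 -> conv(5) -> 24 -> pool(2) -> 12 -> conv(5) -> 8 -> pool(2) -> 4, so the first stage outputs 20 * 4 * 4 = 320 features per sample.

x = torch.randn(2, 1, 28, 28)      # dummy micro-batch
h = Pipe1()(x)
print(h.shape)                     # torch.Size([2, 20, 4, 4])
print(Pipe2()(h).shape)            # torch.Size([2, 10])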

  2. Make the pipeline:
pipe = [Pipe1(), Pipe2()]

  3. Run a micro-batch through the pipeline:
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(train_loader):
        output = pipe[1](pipe[0](data))
        loss = F.nll_loss(output, target)   # no graph is recorded here, so this loss cannot be back-propagated
        pipe[0].conv1.weight *= 3

And... what is step 4? Given the current loss, I would ideally like to compute the derivatives with respect to the new, updated weights. Is there any "proper" way of doing this in PyTorch? Would I need to compute the gradients myself? I would prefer not to have to do that per layer type.
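The closest thing to "proper" I can think of is the functional API: with PyTorch >= 2.0, torch.func.functional_call runs a module with an explicitly supplied parameter dict, so the weight "update" can be done out of place and the gradients taken with respect to the updated values. A rough sketch (stage_loss and scaled_params1 are my names; data and target come from the training loop above):

from torch.func import functional_call, grad

def stage_loss(params1, params2, data, target):
    # run both stages with externally supplied parameters instead of the modules' own leaves
    h = functional_call(pipe[0], params1, (data,))
    out = functional_call(pipe[1], params2, (h,))
    return F.nll_loss(out, target)

params1 = dict(pipe[0].named_parameters())
params2 = dict(pipe[1].named_parameters())

# "update" the first-stage weight out of place instead of mutating the leaf tensor
scaled_params1 = {k: (v * 3 if k == "conv1.weight" else v) for k, v in params1.items()}

# gradients of the loss with respect to the updated first-stage parameters
grads1 = grad(stage_loss, argnums=0)(scaled_params1, params2, data, target)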

The Jacobian computes the gradient with respect to the inputs, so that won't be useful.
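(I assume this refers to torch.autograd.functional.jacobian, which differentiates the output with respect to the tensors passed as inputs, not the module parameters:)

from torch.autograd.functional import jacobian

x = torch.randn(1, 1, 28, 28)
J = jacobian(pipe[0], x)   # shape = output shape + input shape, i.e. d(output)/d(input)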

I am well aware that the activations passed to the second stage of the pipeline (Pipe2) are computed with weights different from those I compute the gradients for in Pipe1. This is intended behaviour (see PipeMare).

The best I can come up with is the following (disgusting) solution:

  1. Do not take gradients on the first stage; keep the gradient on the first-stage output (the second-stage input) and compute the loss:
with torch.no_grad():
    inpt_tmp = pipe[0](data)      # stage-1 forward, no graph recorded

inpt_tmp.requires_grad = True     # treat the stage-1 output as a fresh leaf
output = pipe[1](inpt_tmp)        # stage-2 forward, with autograd enabled
loss = F.nll_loss(output, target)
loss.backward()                   # fills inpt_tmp.grad and the stage-2 parameter grads
optimizers[1].step()
  2. Update the parameter, run a new forward pass, and compute a new backward pass using the gradient of the old stage-1 output as the external gradient:
with torch.no_grad():
    pipe[0].conv1.weight *= 3     # the in-place weight update

tmp = pipe[0](data)               # second stage-1 forward, now with the updated weights
tmp.backward(inpt_tmp.grad)       # back-propagate d(loss)/d(stage-1 output) into the weight grads

This requires 2 forward passes on the first stage.
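If it helps, step 2 can also be written with torch.autograd.grad, which returns the stage-1 weight gradients directly instead of accumulating them through tmp.backward (still two forward passes; a sketch, skipping the unused fc layers of Pipe1):

tmp = pipe[0](data)                  # forward again, with the updated weights
conv_params = [pipe[0].conv1.weight, pipe[0].conv1.bias,
               pipe[0].conv2.weight, pipe[0].conv2.bias]
# chain rule: combine d(stage-1 output)/d(weights) with d(loss)/d(stage-1 output) from step 1
weight_grads = torch.autograd.grad(tmp, conv_params, grad_outputs=inpt_tmp.grad)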
