I know the question seems like it answers itself (compute weight gradients without gradients?). But the issue I am trying to resolve is that I will need to change leaf variables before computing the backward pass. Without no_grad, PyTorch will complain about in-place modifications.
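For a minimal illustration of the complaint, outside of any pipeline:

import torch

w = torch.randn(3, requires_grad=True)
w *= 3
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.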
- Define the model (it is split into two stages of a pipeline):
import torch
import torch.nn as nn
import torch.nn.functional as F

class Pipe1(nn.Module):
    def __init__(self):
        super(Pipe1, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        return x
class Pipe2(nn.Module):
    def __init__(self):
        super(Pipe2, self).__init__()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
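(Shape check, assuming MNIST-sized 1x28x28 inputs, which is where the 320 comes from:)

x = torch.randn(1, 1, 28, 28)
print(Pipe1()(x).shape)  # torch.Size([1, 20, 4, 4]); 20 * 4 * 4 = 320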
- Make the pipeline:
pipe = [Pipe1(), Pipe2()]
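(Not shown above: the optimizers list used further down. I assume one optimizer per stage; SGD is just a placeholder here:)

import torch.optim as optim

optimizers = [optim.SGD(stage.parameters(), lr=0.01) for stage in pipe]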
- Run a micro-batch through the pipeline:
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(train_loader):
        output = pipe[1](pipe[0](data))
        loss = F.nll_loss(output, target)
        pipe[0].conv1.weight *= 3
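For reference, loss here carries no graph (everything ran under no_grad), so a plain backward pass is off the table:

loss.backward()
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn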
And... what is step four? Given the current loss, I would ideally want to compute the derivatives with respect to the newly updated weights. Is there any "proper" way of doing this in PyTorch? Would I need to compute the gradients myself? I would prefer not to have to do that per type of layer.
torch.autograd.functional.jacobian computes derivatives with respect to the inputs, not the weights, so that won't be useful.
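A quick check: jacobian differentiates with respect to whatever is passed as inputs, while the weights are merely closed over:

from torch.autograd.functional import jacobian

x = torch.randn(4)
lin = nn.Linear(4, 2)
J = jacobian(lin, x)
print(J.shape)  # torch.Size([2, 4]): d(output)/d(input); the weights never appear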
I am well aware that the activations passed to the second stage of the pipeline (Pipe2) are computed on different weights than those with respect to which I compute the gradients for Pipe1. This is intended behaviour (see PipeMare).
The best I can come up with is the following (disgusting) solution:
- Run the first stage without gradients, make the first stage's output (the second stage's input) require grad, and compute the loss through the second stage:
with torch.no_grad():
    inpt_tmp = pipe[0](data)
inpt_tmp.requires_grad = True   # treat the stage boundary as a leaf
output = pipe[1](inpt_tmp)      # second stage runs with grad enabled
loss = F.nll_loss(output, target)
loss.backward()                 # fills inpt_tmp.grad and the Pipe2 weight grads
optimizers[1].step()
- Update the parameter, run a new forward pass, and backpropagate through the new output, using the gradient of the old output as the external gradient:
with torch.no_grad():
    pipe[0].conv1.weight *= 3   # the in-place update that motivated all this
tmp = pipe[0](data)             # second first-stage forward, now tracking grad
tmp.backward(inpt_tmp.grad)     # chain the saved upstream gradient through the new weights
This requires two forward passes on the first stage.
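For completeness, here is the whole per-batch procedure in one piece, under the same assumptions as above (the per-stage optimizers list, and the factor-3 multiplication standing in for the real weight update; the optimizers[0].step() and the grad-zeroing are my additions):

for batch_idx, (data, target) in enumerate(train_loader):
    # Stage-1 forward without grad; the graph is cut at the stage boundary.
    with torch.no_grad():
        inpt_tmp = pipe[0](data)
    inpt_tmp.requires_grad = True

    # Stage-2 forward/backward on the old activations.
    output = pipe[1](inpt_tmp)
    loss = F.nll_loss(output, target)
    loss.backward()                 # fills inpt_tmp.grad and the Pipe2 weight grads
    optimizers[1].step()

    # In-place weight update, second stage-1 forward, then chain the saved
    # boundary gradient through the new weights.
    with torch.no_grad():
        pipe[0].conv1.weight *= 3   # stand-in for the real update
    tmp = pipe[0](data)
    tmp.backward(inpt_tmp.grad)
    optimizers[0].step()            # assumption: the stage-1 step goes here

    for opt in optimizers:          # zero grads for the next micro-batch
        opt.zero_grad()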