This fails
import torch
def test1():
    layer = torch.nn.Linear(100, 10)
    x = 5 - torch.sum(layer(torch.ones(100)))
    x.backward()
    layer.weight.data = layer.weight.data[:, :90]
    layer.weight.grad.data = layer.weight.grad.data[:, :90]
    x = 5 - torch.sum(layer(torch.ones(90)))
    x.backward()
test1()
with error
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-bb36a010bd86> in <cell line: 10>()
8 x = 5 - torch.sum(layer(torch.ones(90)))
9 x.backward()
---> 10 test1()
11 # and this works as well
12
2 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
RuntimeError: Function TBackward0 returned an invalid gradient at index 0 - got [10, 90] but expected shape compatible with [10, 100]
This works
import torch
def test2():
    layer = torch.nn.Linear(100, 10)
    x = 5 - torch.sum(layer(torch.ones(100)))
    x.backward()
    del x  # main change
    layer.weight.data = layer.weight.data[:, :90]
    layer.weight.grad.data = layer.weight.grad.data[:, :90]
    x = 5 - torch.sum(layer(torch.ones(90)))
    x.backward()
test2()
and this works as well
import torch
def test3():
    layer = torch.nn.Linear(100, 10)
    x = 5 - torch.sum(layer(torch.ones(100)))
    x.backward()
    layer.weight.data = layer.weight.data[:, :90]
    layer.weight.grad.data = layer.weight.grad.data[:, :90]
    layer.weight = torch.nn.Parameter(layer.weight)  # main change
    x = 5 - torch.sum(layer(torch.ones(90)))
    x.backward()
test3()
I encountered this while trying to implement a paper on model pruning (Temporal Neuron Variance Pruning). I believe this has something to do with the autograd graph, but I am not sure what exactly is going on. I have already seen the link on pruning and got my code working using the third snippet. I am now trying to figure out why snippet 1 fails while snippets 2 and 3 work. Is there an explanation for why these almost identical code snippets behave so differently?
Major points I'd like to figure out:
- what TBackward0 is, and where it is defined
- where the runtime error is raised
- why compatibility with the old shape is expected, especially when the grad has been modified correctly (I assume I have edited the tensors correctly, because cases 2 and 3 work)
- whether I can change something else (other than the two working cases) to make this work
Like you guessed, the issue is with the computational graph, which gets created during the forward pass and consumed when you call backward().
Let me explain the above point: when you create a tensor with requires_grad=True in PyTorch (nn.Parameter does this for you), the operations you perform on it are tracked. During the forward pass, the backward function for each operation (its grad_fn) is recorded, and these nodes are linked into a graph that backward() later walks.
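As a minimal sketch of that recording (a toy tensor, not the code from your question):

import torch

w = torch.ones(3, requires_grad=True)   # tracked leaf tensor
y = (w * 2).sum()                       # forward pass records the graph

print(y.grad_fn)                 # <SumBackward0 ...>, recorded during the forward pass
print(y.grad_fn.next_functions)  # links back through MulBackward0 toward w's AccumulateGrad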
In case 1, the old graph from the first forward/backward is still alive when the second forward pass runs (the right-hand side of x = 5 - torch.sum(...) is evaluated before x is rebound), and that old graph keeps the weight's gradient-accumulator node alive with its recorded shape of [10, 100]. The second forward pass links its new nodes to that stale accumulator, so the second backward() fails the shape check. In case 2, del x drops the last reference to the old graph before the second forward pass, so the stale nodes are freed and fresh ones are created for the new [10, 90] shape. In case 3, wrapping the sliced tensor in a new torch.nn.Parameter replaces the weight entirely, so it gets a fresh accumulator node regardless of the old graph.
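To answer your last bullet: anything that drops the last reference to the old graph before the second forward pass should work. For example, rebinding x to None first; this is a sketch of the same mechanism as case 2, not a different technique:

import torch

def test4():
    layer = torch.nn.Linear(100, 10)
    x = 5 - torch.sum(layer(torch.ones(100)))
    x.backward()
    x = None  # drops the old graph before the next forward pass, same effect as `del x`
    layer.weight.data = layer.weight.data[:, :90]
    layer.weight.grad.data = layer.weight.grad.data[:, :90]
    x = 5 - torch.sum(layer(torch.ones(90)))
    x.backward()

test4()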
Both the output tensor and the model parameters are connected to that graph: x holds it through its grad_fn, and each parameter is attached as a leaf through an AccumulateGrad node.
If you want to see exactly where the TBackward0 function sits, use torchviz to visualize the computational graph. TBackward0 is the node autograd records for the transpose of the weight inside the linear layer; like the other *Backward0 nodes, it is code-generated from derivatives.yaml into torch/csrc/autograd/generated/, so you won't find a hand-written definition. The shape check that raises the RuntimeError is validate_outputs in the autograd engine (torch/csrc/autograd/engine.cpp): it compares the gradient each node returns against the recorded input metadata of the next node, which in your case is the weight's stale [10, 100] accumulator.
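For example (assuming torchviz is installed, e.g. via pip install torchviz, plus a system graphviz), you can render the graph to a file, or just walk grad_fn.next_functions by hand to print the same structure:

import torch
from torchviz import make_dot

layer = torch.nn.Linear(100, 10)
x = 5 - torch.sum(layer(torch.ones(100)))

# Render the graph to graph.png; TBackward0 shows up between the matmul
# and the AccumulateGrad node of layer.weight.
make_dot(x, params=dict(layer.named_parameters())).render("graph", format="png")

# The same structure without torchviz: recursively print the backward nodes.
def walk(fn, depth=0):
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        if next_fn is not None:
            walk(next_fn, depth + 1)

walk(x.grad_fn)  # ends at AccumulateGrad leaves, with TBackward0 along the way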