I am trying to print the value of each of the intermediate gradients during the backward pass of a model, using registered backward hooks:
```python
import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b
        return sum_x

# hook function
def backward_hook(module, grad_input, grad_output):
    print("module: ", module)
    print("inp: ", grad_input)
    print("out: ", grad_output)

# Training
# Generate labels
a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a * x + (0.1 ** 0.5) * torch.randn_like(x) * 0.001 + b

inp = torch.linspace(-1, 1, 10)
foo = func_NN()
handle_ = foo.register_full_backward_hook(backward_hook)
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(foo.parameters(), lr=0.001)
t_l = []
for i in range(2):
    optim.zero_grad()
    l = loss(y, foo.forward(inp=inp))
    t_l.append(l.detach())
    l.backward()
    optim.step()
handle_.remove()
```
But this does not produce the desired result. My objective is to print the gradients of the non-leaf nodes such as `sum_x` and `mul_x`. Please help.
PyTorch module backward hooks only expose gradients at the module boundary, i.e. with respect to the module's inputs and outputs; you cannot use them to grab the gradients of intermediate tensors created inside `forward`. (Note also that calling `foo.forward(inp=inp)` directly bypasses the hook machinery entirely; the module must be called as `foo(inp=inp)` for any registered hooks to fire.)

If you want to get the gradients of intermediate tensors, you need to save them to the model's state and call `retain_grad()` on them, which tells autograd to populate their `.grad` attribute even though they are non-leaf tensors.
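Here is a minimal sketch of that approach, reusing your `func_NN` (the attribute names `self.mul_x` / `self.sum_x` and the dummy target are illustrative, not required):

```python
import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        # Save the intermediates on the module and mark them with
        # retain_grad() so autograd populates their .grad on backward().
        self.mul_x = torch.cos(self.a.view(-1, 1) * inp)
        self.mul_x.retain_grad()
        self.sum_x = self.mul_x - self.b
        self.sum_x.retain_grad()
        return self.sum_x

foo = func_NN()
inp = torch.linspace(-1, 1, 10)
y = torch.rand(1, 10)  # dummy target standing in for your generated labels
l = torch.nn.MSELoss()(foo(inp), y)
l.backward()

# The gradients of the non-leaf tensors are now available:
print("mul_x grad: ", foo.mul_x.grad)
print("sum_x grad: ", foo.sum_x.grad)
```

If you call `backward()` twice on the same graph, the retained `.grad` values accumulate just like parameter gradients; in a normal training loop each forward pass creates fresh intermediate tensors, so you simply read `foo.mul_x.grad` and `foo.sum_x.grad` after each `l.backward()`.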