Print intermediate gradient values during the backward pass in PyTorch using hooks


I am trying to print the value of each intermediate gradient during the backward pass of a model, using register_full_backward_hook:

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1,1)*inp)
        sum_x = mul_x - self.b
        return sum_x

# hook function
def backward_hook(module, grad_input, grad_output):
    print("module: ", module)
    print("inp: ", grad_input)
    print("out: ", grad_output) 

# Training
# Generate labels
a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
handle_ = foo.register_full_backward_hook(backward_hook)
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(foo.parameters(),lr=0.001)

t_l = []
for i in range(2):
    optim.zero_grad()
    l = loss(y, foo.forward(inp=inp))
    t_l.append(l.detach())
    l.backward()
    optim.step()
handle_.remove()

But this does not provide the desired result.

My objective is to print the gradients of the non-leaf nodes like sum_x and mul_x. Please help.

1 Answer

PyTorch module backward hooks give you the gradients with respect to a module's inputs and outputs (grad_input and grad_output). You cannot use them to grab the gradients of intermediate tensors created inside forward.

If you want the gradients of intermediate tensors, call retain_grad() on them and save references to them on the module, so you can read their .grad after the backward pass.

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b

        # Retain gradients for intermediate variables
        mul_x.retain_grad()
        sum_x.retain_grad()

        # Store references to the intermediate tensors
        self.mul_x = mul_x
        self.sum_x = sum_x

        return sum_x

a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
loss = torch.nn.MSELoss()


# One forward and backward pass
l = loss(y, foo.forward(inp=inp))
l.backward()

# Gradients of the retained intermediate tensors are now populated
print(foo.mul_x.grad)
print(foo.sum_x.grad)
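
If you specifically want the values printed as they are computed during backward() (as in the original attempt), tensor-level hooks are an alternative to retain_grad: Tensor.register_hook attaches a function that is called with the gradient of that tensor during the backward pass. Below is a minimal sketch reusing the same toy module; the made-up target y and the print labels are only for illustration:

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b

        # Tensor-level hooks: each is called with the gradient of its
        # tensor while backward() runs, so the values print immediately
        mul_x.register_hook(lambda grad: print("grad of mul_x:", grad))
        sum_x.register_hook(lambda grad: print("grad of sum_x:", grad))

        return sum_x

inp = torch.linspace(-1, 1, 10)
y = 0.5 * inp + 0.8  # made-up target, only for illustration

foo = func_NN()
loss_fn = torch.nn.MSELoss()

l = loss_fn(foo(inp).flatten(), y)
l.backward()  # the hooks print the intermediate gradients here

Unlike retain_grad, this does not store the gradients on the tensors; each hook just receives the gradient as it flows through, which is handy for printing or logging without keeping extra state on the module.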