I want to calculate the Hessian matrix of a loss w.r.t. the model parameters in PyTorch. Using torch.autograd.functional.hessian is not an option for me, since it recomputes the model output and the loss, both of which I already have from previous calls. My current implementation is as follows:
import torch
import time
# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
num_param = sum(p.numel() for p in model.parameters())
# Evaluate some loss on a random dataset
x = torch.rand((1000,1))
y = torch.rand((1000,1))
y_hat = model(x)
loss = ((y_hat - y)**2).mean()
# Calculate Hessian
start = time.time()
# Allocate Hessian size
H = torch.zeros((num_param, num_param))
# Calculate Jacobian (i.e. the gradient, since the loss is a scalar) w.r.t. model parameters; create_graph=True keeps it differentiable
J = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
J = torch.cat([e.flatten() for e in J]) # flatten
# Fill in Hessian
for i in range(num_param):
    # Differentiate the i-th gradient entry to get row i of the Hessian
    result = torch.autograd.grad(J[i], list(model.parameters()), retain_graph=True)
    H[i] = torch.cat([r.flatten() for r in result]) # flatten
print(time.time() - start)
Is there any way to do this faster? Perhaps without using the for loop, since it is calling autograd.grad for every single model variable.
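One loop-free possibility, sketched here rather than benchmarked, is to pass `is_grads_batched=True` to `torch.autograd.grad`. This assumes PyTorch 1.11 or later, where the flag exists but is documented as experimental, so it may not support every operation. The idea is to feed the identity matrix as `grad_outputs`, so that each row picks out one gradient entry and all rows are handled in a single vectorized backward pass. The loss is computed once and reused. The names `g`, `blocks` and `params` are illustrative choices, not part of the code above.

```python
import torch

torch.manual_seed(0)

# Same setup as in the question
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
params = list(model.parameters())
num_param = sum(p.numel() for p in params)

x = torch.rand((1000, 1))
y = torch.rand((1000, 1))
loss = ((model(x) - y) ** 2).mean()  # computed once, not re-evaluated below

# Flattened gradient of the loss; create_graph=True keeps it differentiable
g = torch.autograd.grad(loss, params, create_graph=True)
g = torch.cat([e.flatten() for e in g])

# Row i of the identity selects g[i]; is_grads_batched=True treats the first
# dimension of grad_outputs as a batch and vectorizes over it.
# retain_graph=True keeps the graph alive in case g is differentiated again.
blocks = torch.autograd.grad(
    g,
    params,
    grad_outputs=torch.eye(num_param),
    is_grads_batched=True,
    retain_graph=True,
)

# blocks[k] has shape (num_param, *params[k].shape); flatten and join the columns
H = torch.cat([b.reshape(num_param, -1) for b in blocks], dim=1)
print(H.shape)
```

Whether this is actually faster than the loop depends on the model size and the device, and it needs memory for all `num_param` backward passes at once, so it is worth timing on the real problem.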
One way to make it faster is using functorch.hessian (based on this issue). However, it has to recompute the loss every time a Hessian is calculated (while I already have access to the loss). Nevertheless, I'll post it for those who are interested. I still think it is far too slow.