Fast way to calculate Hessian matrix of model parameters in PyTorch


I want to compute the Hessian matrix of a loss with respect to the model parameters in PyTorch. Using torch.autograd.functional.hessian is not an option for me, because it recomputes the model output and the loss, both of which I already have from previous calls. My current implementation is as follows:

import torch
import time

# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
num_param = sum(p.numel() for p in model.parameters())

# Evaluate some loss on a random dataset
x = torch.rand((1000,1))
y = torch.rand((1000,1))
y_hat = model(x)
loss = ((y_hat - y)**2).mean()

# Calculate Hessian
start = time.time()

# Allocate Hessian size
H = torch.zeros((num_param, num_param))

# Calculate the gradient of the scalar loss w.r.t. model parameters (kept differentiable via create_graph=True)
J = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
J = torch.cat([e.flatten() for e in J]) # flatten

# Fill in Hessian
for i in range(num_param):
    result = torch.autograd.grad(J[i], list(model.parameters()), retain_graph=True)
    H[i] = torch.cat([r.flatten() for r in result]) # flatten

print(time.time() - start)

Is there any way to do this faster? Ideally without the for loop, since it calls autograd.grad once for every single model parameter.
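One loop-free variant that keeps the existing `loss` graph can be sketched as follows. It assumes a recent PyTorch release in which `torch.autograd.grad` accepts `is_grads_batched=True` (documented as experimental), which vectorizes the backward pass over the first dimension of `grad_outputs`. Passing an identity matrix then requests every Hessian row in a single call. The names `params`, `g`, and `rows` are chosen for this sketch only, and whether it is actually faster depends on the model and hardware; it also materializes all rows at once, so memory grows with `num_param**2`.

```python
import torch

torch.manual_seed(0)

# Same setup as in the question
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
params = list(model.parameters())
num_param = sum(p.numel() for p in params)

x = torch.rand((1000, 1))
y = torch.rand((1000, 1))
loss = ((model(x) - y) ** 2).mean()

# Gradient of the loss, kept differentiable so it can be differentiated again
g = torch.autograd.grad(loss, params, create_graph=True)
g = torch.cat([e.flatten() for e in g])

# Each row of the identity selects one gradient entry; is_grads_batched=True
# vectorizes the backward pass over those rows instead of looping in Python.
rows = torch.autograd.grad(
    g, params,
    grad_outputs=torch.eye(num_param),
    is_grads_batched=True,
    retain_graph=True,
)

# rows[k] has shape (num_param, *params[k].shape); flatten per parameter and join
H = torch.cat([r.reshape(num_param, -1) for r in rows], dim=1)
print(H.shape)
```

Because the model output and loss are computed once and only differentiated, this stays within the constraint of reusing results from earlier calls.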

Answer by Thomas Wagenaar:

One way to make it faster is to use functorch.hessian (based on this issue). However, it has to recompute the loss every time a Hessian is calculated, even though I already have access to the loss. Nevertheless, I'll post it for those who are interested. I still think it is far too slow.

import torch
from functorch import hessian
from torch.nn.utils import _stateless
import time

# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
num_param = sum(p.numel() for p in model.parameters())
names = list(n for n, _ in model.named_parameters())

# Create random dataset
x = torch.rand((1000,1))
y = torch.rand((1000,1))

# Define loss function
def loss(params):
    y_hat = _stateless.functional_call(model, {n: p for n, p in zip(names, params)}, x)
    return ((y_hat - y)**2).mean()

# Calculate Hessian
hessian_func = hessian(loss)

start = time.time()

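params = tuple(model.parameters())
H = hessian_func(params)
# H[i][j] has shape (*params[i].shape, *params[j].shape): reshape each block
# to 2-D before stitching, otherwise entries end up in the wrong rows
H = torch.cat([torch.cat([e.reshape(p.numel(), -1) for e in Hpart], dim=1)
               for p, Hpart in zip(params, H)], dim=0)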

print(time.time() - start)
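The same idea can be written against `torch.func`, assuming PyTorch 2.0 or later, where the functorch transforms ship with PyTorch itself. The sketch below passes the parameters as a dict, so the result is a nested dict of blocks that is reshaped and stitched into one matrix. The names `loss_fn` and `blocks` are illustrative, and like the version above it recomputes the loss on every call.

```python
import torch
from torch.func import functional_call, hessian

torch.manual_seed(0)

# Same model and data as above
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
x = torch.rand((1000, 1))
y = torch.rand((1000, 1))

# Detached copies of the parameters, keyed by name, act as the function input
params = {n: p.detach() for n, p in model.named_parameters()}
names = list(params)
num_param = sum(p.numel() for p in params.values())

def loss_fn(params):
    # Run the model with the given parameters instead of its own
    y_hat = functional_call(model, params, (x,))
    return ((y_hat - y) ** 2).mean()

# Nested dict: blocks[a][b] has shape (*params[a].shape, *params[b].shape)
blocks = hessian(loss_fn)(params)

# Reshape every block to 2-D and stitch into one (num_param, num_param) matrix
H = torch.cat([
    torch.cat([blocks[a][b].reshape(params[a].numel(), params[b].numel()) for b in names], dim=1)
    for a in names
], dim=0)
print(H.shape)
```

Keying the blocks by parameter name makes it harder to mix up row and column blocks when assembling the full matrix.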