A faster Hessian vector product in PyTorch


I need to take a Hessian-vector product (HVP) of a loss with respect to the model parameters a large number of times. There seems to be no efficient way to do this: a for loop is always required, resulting in a large number of independent autograd.grad calls. My current implementation is given below and is representative of my use case. Note that in the real case the vectors v are not all known beforehand, so they cannot simply be batched.

import torch
import time

# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 500), torch.nn.Tanh(), torch.nn.Linear(500, 1))
num_param = sum(p.numel() for p in model.parameters())

# Evaluate some loss on a random dataset
x = torch.rand((10000, 1))
y = torch.rand((10000, 1))
y_hat = model(x)
loss = ((y_hat - y)**2).mean()

# Calculate gradient of the loss w.r.t. model parameters,
# keeping the graph so it can be differentiated through again
J = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
J = torch.cat([e.flatten() for e in J]) # flatten

# Calculate Hessian vector product
start_time = time.time()
for i in range(10):
    v = torch.rand(num_param)
    HVP = torch.autograd.grad(J, list(model.parameters()), grad_outputs=v, retain_graph=True)
print('Time per HVP: ', (time.time() - start_time)/10)

This takes around 0.05 s per Hessian-vector product on my machine. Is there a way to speed this up, especially considering that the Hessian itself does not change between calls?
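One direction I have been looking at is the forward-over-reverse recipe from the torch.func docs (a sketch assuming PyTorch 2.x; loss_fn is a functional wrapper I added, not part of my code above):

from torch.func import functional_call, grad, jvp

params = dict(model.named_parameters())

# Functional version of the loss, closing over the x and y defined above
def loss_fn(p):
    y_hat = functional_call(model, p, (x,))
    return ((y_hat - y) ** 2).mean()

# Forward-over-reverse: jvp of the gradient function gives H @ v
# without ever materializing H
def hvp(p, tangents):
    return jvp(grad(loss_fn), (p,), (tangents,))[1]

v = {k: torch.rand_like(t) for k, t in params.items()}
Hv = hvp(params, v)  # same dict structure as params

But I do not know whether this actually beats the double-backward loop above, since it still rebuilds the computation on every call.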
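Since the Hessian itself is fixed between calls and num_param is only 1501 here, I have also considered materializing H once from the J above and reducing each HVP to a matrix-vector product:

# One-time cost: num_param extra backward passes through the retained graph
rows = []
for i in range(num_param):
    row = torch.autograd.grad(J[i], list(model.parameters()), retain_graph=True)
    rows.append(torch.cat([e.flatten() for e in row]))
H = torch.stack(rows)  # (num_param, num_param)

# Every subsequent HVP is then a cheap mat-vec
v = torch.rand(num_param)
HVP = H @ v

This only amortizes if the number of HVPs is large relative to num_param, so it does not feel like a general answer.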
