I am relatively new to PyTorch and am trying to compute the Hessian of a very simple feedforward network with respect to its weights, using torch.autograd.functional.hessian. I have been digging through the forums, but since this function was only recently added to PyTorch, I can't find much information on it. Here is my network architecture, taken from some sample MNIST code on Kaggle:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# input_size, hidden_size, output_size and learning_rate are defined elsewhere.

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l3(x)
        return F.log_softmax(x, dim=1)

net = Network()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
loss_func = nn.CrossEntropyLoss()
```
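A side note on this definition: `nn.CrossEntropyLoss` already applies `log_softmax` internally, so ending `forward` with `F.log_softmax` is redundant. It does not change the loss value, because `log_softmax` of log-probabilities returns them unchanged, but `nn.NLLLoss` is the usual partner for a `log_softmax` output. A quick check with random logits (the batch size and class count are hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)          # hypothetical batch of 5 samples, 10 classes
target = torch.randint(0, 10, (5,))

# Intended usage: raw logits into CrossEntropyLoss.
ce_raw = nn.CrossEntropyLoss()(logits, target)
# As in the question: log_softmax output into CrossEntropyLoss.
ce_double = nn.CrossEntropyLoss()(F.log_softmax(logits, dim=1), target)
# Conventional pairing: log_softmax output into NLLLoss.
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

print(torch.allclose(ce_raw, ce_double), torch.allclose(ce_raw, nll))  # True True
```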
I then train the network for a number of epochs:
```python
loss_log = []
for e in range(epochs):
    for i in range(0, x.shape[0], batch_size):
        x_mini = x[i:i + batch_size]
        y_mini = y[i:i + batch_size]

        optimizer.zero_grad()
        net_out = net(x_mini)  # Variable wrappers are unnecessary since PyTorch 0.4
        loss = loss_func(net_out, y_mini)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            loss_log.append(loss.item())
```
Then I flatten all the parameters and concatenate them into a single tensor:
```python
param_list = []
for param in net.parameters():
    param_list.append(param.view(-1))
param_list = torch.cat(param_list)
```
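For reference, PyTorch provides the same flattening as `torch.nn.utils.parameters_to_vector`. The sketch below uses an `nn.Sequential` stand-in for the question's `Network` with small hypothetical layer sizes, and also shows why the Hessian grows quickly: for n parameters it has n × n entries.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

# Hypothetical sizes, chosen small for illustration only.
input_size, hidden_size, output_size = 4, 3, 2

net = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size),
)

# Manual flattening, as in the question.
manual = torch.cat([p.view(-1) for p in net.parameters()])

# Built-in equivalent: concatenates the flattened parameters in the same order.
flat = parameters_to_vector(net.parameters())

n = flat.numel()  # 4*3 + 3 + 3*2 + 2 = 23
print(n, "parameters ->", n * n, "Hessian entries")  # 23 parameters -> 529 Hessian entries
```

For an MNIST-sized model (784 inputs, 10 outputs, and a hypothetical `hidden_size` of 100) there are 784·100 + 100 + 100·10 + 10 = 79,510 parameters, so the full Hessian has roughly 6.3 × 10⁹ entries, about 25 GB in float32.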
Finally, I am trying to compute the Hessian of the converged network by running:
```python
hessian = torch.autograd.functional.hessian(loss_func, param_list, create_graph=True)
```
but it gives me this error:

```
TypeError: forward() missing 1 required positional argument: 'target'
```
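The error follows from how `hessian(func, inputs)` works: it calls `func(inputs)` with only the tensor(s) it is given and differentiates the scalar that comes back twice. With `func=loss_func`, that becomes `CrossEntropyLoss.forward(param_list)`, so `target` is never supplied. A toy function of a single tensor (not from the question) shows the expected contract:

```python
import torch
from torch.autograd.functional import hessian

def f(v):
    # hessian() calls f with the single tensor it was given and expects
    # a scalar back; here f(v) = sum(v_i^3).
    return (v ** 3).sum()

v = torch.tensor([1.0, 2.0, 3.0])
H = hessian(f, v)
# The Hessian of sum(v_i^3) is diagonal with entries 6 * v_i, i.e. diag(6, 12, 18).
print(H)
```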
Any help would be appreciated.
Computing the Hessian with respect to the parameters of a model (as opposed to the inputs to the model) isn't well supported right now. There's some work being done on this at https://github.com/pytorch/pytorch/issues/49171, but for the moment it's very inconvenient.
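One way to get a parameter Hessian in the flat layout from the question is sketched below. This is not the code from this answer; it assumes PyTorch 2.0 or later for `torch.func.functional_call`. The idea is to wrap the model in a function of one flat parameter vector that unflattens it, runs the model with those weights, and returns the scalar loss, so `torch.autograd.functional.hessian` can differentiate it. The tiny model, random data, and the `loss_of_flat` helper are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.functional import hessian
from torch.func import functional_call  # assumes PyTorch >= 2.0

torch.manual_seed(0)

# Hypothetical tiny model and data, standing in for the question's network and MNIST batch.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

# Record names, shapes and sizes so a flat vector can be split back into parameters.
names = [name for name, _ in net.named_parameters()]
shapes = [p.shape for _, p in net.named_parameters()]
sizes = [p.numel() for _, p in net.named_parameters()]
flat0 = torch.cat([p.detach().reshape(-1) for _, p in net.named_parameters()])

def loss_of_flat(flat):
    # Unflatten into a {name: tensor} dict and run the model with those weights.
    chunks = flat.split(sizes)
    params = {name: c.view(s) for name, c, s in zip(names, chunks, shapes)}
    out = functional_call(net, params, (x,))
    return F.cross_entropy(out, y)  # scalar, as hessian() requires

H = hessian(loss_of_flat, flat0)
print(H.shape)  # torch.Size([23, 23])
```

The flat vector uses the same ordering as `net.parameters()`, so rows and columns of `H` line up with the question's `param_list`.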
Your code has a few other problems. Where you're passing `loss_func`, you should be passing a function that constructs the computation graph. Also, you never specify the input to the network or the target for the loss function. Here's some code that cheats a little bit to use the existing functional interface to compute the Hessian of the model weights, and concatenates everything together to give the same form as what you were trying to do: