Computing the Hessian of a Simple NN in PyTorch with Respect to Its Parameters

I am relatively new to PyTorch and trying to compute the Hessian of a very simple feedforward network with respect to its weights. I am trying to get torch.autograd.functional.hessian to work. I have been digging through the forums, but since this is a relatively new function in PyTorch, I am unable to find much information on it.
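From the docs, hessian takes a function of the tensors to differentiate and returns all the second derivatives. Here is a minimal standalone example of the calling convention I am trying to use:

import torch

# f maps a tensor to a scalar
def f(x):
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0])
# The Hessian of sum(x^3) is diag(6 * x): [[6., 0.], [0., 12.]]
print(torch.autograd.functional.hessian(f, x))

Here is my simple network architecture, which is from some sample code on Kaggle for MNIST.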

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Example MNIST-style sizes; the originals were defined elsewhere in the notebook
input_size, hidden_size, output_size = 784, 100, 10
learning_rate = 0.01

class Network(nn.Module):

    def __init__(self):
        super(Network, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l3(x)
        return F.log_softmax(x, dim=1)

net = Network()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
# The network already applies log_softmax, so pair it with NLLLoss;
# CrossEntropyLoss would apply log_softmax a second time
loss_func = nn.NLLLoss()

and I am training the network for a number of epochs like this:

# x, y are the training tensors; epochs and batch_size are defined elsewhere
loss_log = []
for e in range(epochs):
    for i in range(0, x.shape[0], batch_size):
        x_mini = x[i:i + batch_size]
        y_mini = y[i:i + batch_size]
        optimizer.zero_grad()
        net_out = net(x_mini)
        loss = loss_func(net_out, y_mini)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            # .item() extracts the scalar; Variable and .data are deprecated
            loss_log.append(loss.item())

Then I flatten every parameter and concatenate them into a single tensor, as below:

param_list = []
for param in net.parameters():
    param_list.append(param.view(-1))
param_list = torch.cat(param_list)
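For what it's worth, torch.nn.utils.parameters_to_vector does the same flattening in one call:

from torch.nn.utils import parameters_to_vector

# Equivalent to the manual view(-1)-and-cat loop above
param_vec = parameters_to_vector(net.parameters())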

Finally, I am trying to compute the Hessian of the converged network by running:

hessian = torch.autograd.functional.hessian(loss_func, param_list, create_graph=True)

but it gives me this error: TypeError: forward() missing 1 required positional argument: 'target'

Any help would be appreciated.

1 Answer

Answered by mlucy

Computing the Hessian with respect to the parameters of a model (as opposed to the inputs to the model) isn't really well supported right now. There's some work being done on this at https://github.com/pytorch/pytorch/issues/49171, but for the moment it's very inconvenient.

Your code has a few other problems as well. Where you're passing loss_func, you should be passing a function that builds the full computation graph from the tensors being differentiated, and you never specify the input to the network or the target for the loss. That is where your TypeError comes from: hessian calls loss_func(param_list), so the loss module's forward never receives its target argument.
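To see the calling convention hessian wants, here is the well-supported case, the Hessian with respect to the network input. This is a minimal sketch, with x0 and target as placeholder data:

# The function's argument is exactly the tensor being differentiated
x0 = torch.rand(1, input_size)
target = torch.ones(1, dtype=torch.long)

def loss_wrt_input(x):
    return loss_func(net(x), target)

# Result has shape (1, input_size, 1, input_size)
h_input = torch.autograd.functional.hessian(loss_wrt_input, x0)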

Here's some code that cheats a little bit to use the existing functional interface to compute the hessian of the model weights, and concatenates everything together to give the same form as what you were trying to do:

# Pick a random input to the network
src = torch.rand(1, input_size)
# Say our target for the loss is class 1
dst = torch.ones(1, dtype=torch.long)

keys = list(net.state_dict().keys())
parameters = list(net.parameters())
sizes = [x.view(-1).shape[0] for x in parameters]
ndims = sum(sizes)

def hessian_hack(*params):
    # Deregister each parameter and replace it with the tensor that
    # hessian() passes in, so the loss becomes a function of `params`
    for i in range(len(keys)):
        path = keys[i].split('.')
        cur = net
        for f in range(len(path) - 1):
            cur = getattr(cur, path[f])
        delattr(cur, path[-1])
        setattr(cur, path[-1], params[i])
    return loss_func(net(src), dst)

# sub_hessians[i][f] is the hessian of parameter i vs parameter f
sub_hessians = torch.autograd.functional.hessian(
    hessian_hack,
    tuple(parameters),
    create_graph=True)

# We can combine them all into one big hessian.
hessian = torch.cat([
    torch.cat([
        sub_hessians[i][f].reshape(sizes[i], sizes[f])
        for f in range(len(sub_hessians[i]))
    ], dim=1)
    for i in range(len(sub_hessians))
], dim=0)
print(hessian)
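As a quick sanity check, the assembled matrix should be square with one row and column per scalar parameter, and symmetric up to floating-point error:

# One row/column per scalar parameter
assert hessian.shape == (ndims, ndims)
# Mixed second derivatives commute, so the Hessian should be symmetric
print(torch.allclose(hessian, hessian.T, atol=1e-6))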