I have a PyTree params (in my case a nested dictionary) containing the parameters of a neural network. My goal is to compute the diagonal entries of the Hessian of a loss function with respect to the parameters and store them in a PyTree with the same structure as the parameters.
When I call jax.hessian(loss_fn)(params, data), I get (as expected) an even more deeply nested dictionary containing the full Hessian.
How can I transform this dictionary to get the desired PyTree with diagonal entries?
To be more concrete: let's say I have only one layer in my network and params is given by
params:
    'linear':
        'w': DeviceArray() of shape [5 x 1]
        'b': DeviceArray() of shape [1]
The returned Hessian has the keys and shapes given by
hessian:
    'linear':
        'b':
            'linear':
                'b': (1, 1),
                'w': (1, 5, 1),
        'w':
            'linear':
                'b': (5, 1, 1),
                'w': (5, 1, 5, 1)
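For completeness, here is a minimal setup that reproduces exactly this structure (the loss function and data below are stand-ins of my own):

```python
import jax
import jax.numpy as jnp

params = {'linear': {'w': jnp.ones((5, 1)), 'b': jnp.zeros((1,))}}

def loss_fn(params, data):
    x, y = data
    pred = x @ params['linear']['w'] + params['linear']['b']
    return jnp.mean((pred - y) ** 2)

data = (jnp.ones((10, 5)), jnp.ones((10, 1)))
hess = jax.hessian(loss_fn)(params, data)
print(jax.tree_util.tree_map(jnp.shape, hess))
# {'linear': {'b': {'linear': {'b': (1, 1), 'w': (1, 5, 1)}},
#             'w': {'linear': {'b': (5, 1, 1), 'w': (5, 1, 5, 1)}}}}
```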
As far as I understand it, I need the entries
jnp.diag(hessian['linear']['b']['linear']['b'])
as the diagonal Hessian for the bias and
jnp.diag(jnp.squeeze(hessian['linear']['w']['linear']['w']))
as the diagonal Hessian for the weights. (However, the squeeze may only work for 1-dimensional outputs; a shape-agnostic variant is sketched below.)
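The same extraction can be written without the squeeze by flattening each diagonal block into a matrix first, which works for leaves of any shape:

```python
# Shape-agnostic version of the extraction above: flatten the diagonal block
# to an (n, n) matrix, take its diagonal, and restore the leaf's shape.
w = params['linear']['w']                      # shape (5, 1)
block = hessian['linear']['w']['linear']['w']  # shape (5, 1, 5, 1)
n = w.size
diag_w = jnp.diagonal(block.reshape(n, n)).reshape(w.shape)
```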
How can I automate this transformation so that it works for more complex models with multiple layers?
I know that this does not scale to huge networks; I need it for testing optimizers.
I ran into the exact same problem. Unfortunately, working with PyTrees in JAX can be awkward. I was also looking for a way to construct the diagonal Hessian entry by entry, since that could yield a practical method.
I now have the following:
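This is a sketch rather than a polished implementation (the name hessian_diag is mine). It relies on the fact that jax.hessian produces, at every leaf of params, an inner tree with the same structure as params, so the diagonal blocks sit at evenly spaced positions in the flattened Hessian:

```python
import jax
import jax.numpy as jnp

def hessian_diag(hess, params):
    """Extract the diagonal of a full Hessian PyTree into the structure of params.

    `hess` is the output of jax.hessian(loss_fn)(params, ...): for every leaf
    of `params` it contains a nested copy of the params structure, and the
    diagonal block of a leaf (shape leaf.shape + leaf.shape) sits at the same
    path in the inner tree as in the outer tree.
    """
    leaves, treedef = jax.tree_util.tree_flatten(params)
    n_leaves = len(leaves)

    # Flattening the Hessian gives n_leaves**2 blocks in outer-leaf-major
    # order, so the i-th diagonal block sits at flat index i * (n_leaves + 1).
    hess_leaves = jax.tree_util.tree_leaves(hess)
    assert len(hess_leaves) == n_leaves ** 2

    diag_leaves = []
    for i, leaf in enumerate(leaves):
        block = hess_leaves[i * (n_leaves + 1)]  # shape: leaf.shape + leaf.shape
        n = leaf.size
        # Flatten the block to (n, n), take its diagonal, restore the leaf shape.
        diag_leaves.append(jnp.diagonal(block.reshape(n, n)).reshape(leaf.shape))

    return jax.tree_util.tree_unflatten(treedef, diag_leaves)
```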
When I try this out on the Hessian of a very simple MLP:
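(the two-layer network, loss, and random data below are my own minimal stand-ins)

```python
def loss_fn(params, data):
    x, y = data
    h = jnp.tanh(x @ params['layer1']['w'] + params['layer1']['b'])
    pred = h @ params['layer2']['w'] + params['layer2']['b']
    return jnp.mean((pred - y) ** 2)

k1, k2, kx, ky = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    'layer1': {'w': jax.random.normal(k1, (5, 3)), 'b': jnp.zeros((3,))},
    'layer2': {'w': jax.random.normal(k2, (3, 1)), 'b': jnp.zeros((1,))},
}
data = (jax.random.normal(kx, (10, 5)), jax.random.normal(ky, (10, 1)))

hess = jax.hessian(loss_fn)(params, data)
diag = hessian_diag(hess, params)
```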
Then, the function returns:
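The exact numbers depend on the random initialization, but the result is a PyTree mirroring params:

```python
print(jax.tree_util.tree_map(jnp.shape, diag))
# {'layer1': {'b': (3,), 'w': (5, 3)}, 'layer2': {'b': (1,), 'w': (3, 1)}}
```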
Upon visual inspection, you can see that the returned elements are indeed the diagonal elements of
hessian, cast to the canonical structure of params. Funnily enough, for the Gauss-Newton approximation to the Hessian the procedure is much simpler: simply take the element-wise square of the Jacobians :).
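A sketch of that last remark (my reading; residual_fn is my own name): for a least-squares loss built from residuals of shape (N,), the Gauss-Newton matrix is J^T J up to a loss-dependent constant (e.g. 2/N for a mean of squares), so its diagonal is the element-wise square of the Jacobian summed over the residual axis:

```python
def residual_fn(params, data):
    x, y = data
    h = jnp.tanh(x @ params['layer1']['w'] + params['layer1']['b'])
    pred = h @ params['layer2']['w'] + params['layer2']['b']
    return (pred - y).ravel()

jac = jax.jacobian(residual_fn)(params, data)  # leaves: (N,) + leaf.shape
gn_diag = jax.tree_util.tree_map(lambda j: jnp.sum(j ** 2, axis=0), jac)
```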