How can I use torch's `VHP` routine inside of the `step` method when defining a custom optimizer?


I have been playing around with implementing an optimizer in torch whose update rule leverages the Hessian-vector product. I am keen on using the existing `VHP` routine for this, but its really awkward API makes it impossible to use inside the `step` method when defining a custom optimizer class.

Here are the problems that I currently have with the existing VHP interface:

  1. One typically does not pass any argument to `step` when calling `optimizer.step()`, except for algorithms like L-BFGS, which require a `closure` argument because they re-evaluate the loss function multiple times per iteration. Given the API for `VHP`, a typical call would look something like `VHP(loss_func, (model_output_tensor, gt_labels/values_tensor), v)` (I have gripes about this signature as well, which I explain in the next point). But `step` does not typically have access to any of these arguments: it knows nothing about `loss_func`, `model_output_tensor`, or `gt_labels`. Just to get it working, I decided to dump all of these arguments inside `closure` (a horrible choice, but it seemed to be the only way out), but then...
  2. It seems that `VHP` treats the `inputs` tuple as separate inputs to the function, to be evaluated independently (which is why it explicitly imposes `len(inputs) == len(v)`). But loss functions almost always take vector inputs, and I have no clue how one could pass two vector inputs in a single tensor (except maybe by vertically stacking them, but I am unsure whether `VHP` would parse the inputs that way).
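
To make the second point concrete, here is a minimal sketch of how `torch.autograd.functional.vhp` handles a tuple of inputs. The function `f` and the tensors below are made up purely for illustration. As far as I can tell, each entry of `inputs` can be a tensor of any shape, `v` must be a tuple of tensors with matching shapes, and the Hessian is taken jointly with respect to all the inputs that are passed in (so for an optimizer these would presumably need to be the parameters, not the model outputs and labels).

```python
import torch
from torch.autograd.functional import vhp

# Hypothetical scalar function of two vector inputs (illustration only).
def f(x, y):
    return (x ** 2).sum() + (x * y).sum()

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])

# One vector per input, each with the same shape as that input.
vx = torch.tensor([1.0, 0.0])
vy = torch.tensor([0.0, 1.0])

# vhp returns (f(x, y), tuple of vector-Hessian products, one per input).
value, (hx, hy) = vhp(f, (x, y), (vx, vy))

print(value)  # function value at (x, y)
print(hx)     # block of v^T H corresponding to x (includes the x-y cross term)
print(hy)     # block of v^T H corresponding to y
```

The `hx` block picks up a contribution from `vy` through the cross term `x * y`, which suggests the inputs are not evaluated independently.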

TLDR: The API for torch's `VHP` method is really hard to work with when defining a custom optimizer in torch. Are there any interesting alternatives for computing the Hessian-vector product that do not run into the pitfalls I mentioned above? As far as I am aware, JAX also has an identical API for its own HVP method.
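
One alternative that might be worth considering is a double backward pass with `torch.autograd.grad`, which only needs the loss and the parameters, both of which `step` can reach through a `closure`. The following is a rough sketch under several assumptions: the class name `HVPSketch`, the attribute `last_hvp`, and the choice of the gradient itself as the vector `v` are all hypothetical, the update is a plain gradient step used as a placeholder, and the closure is assumed to return the loss without calling `backward()`. Since the Hessian of a twice continuously differentiable loss is symmetric, the Hessian-vector and vector-Hessian products coincide in that case.

```python
import torch
from torch.optim import Optimizer

class HVPSketch(Optimizer):
    """Hypothetical optimizer that computes a Hessian-vector product in step()."""

    def __init__(self, params, lr=0.1):
        super().__init__(params, dict(lr=lr))
        self.last_hvp = None  # hypothetical attribute, kept for inspection

    def step(self, closure):
        # Assumption: closure() returns the loss and does NOT call backward().
        params = [p for group in self.param_groups for p in group["params"]]
        with torch.enable_grad():
            loss = closure()
            # First backward pass, keeping the graph so it can be differentiated again.
            grads = torch.autograd.grad(loss, params, create_graph=True)
            # Vector v: here simply the detached gradient (an arbitrary choice).
            vs = [g.detach() for g in grads]
            # d/dp (grad . v) = H v. This fails if the loss has no curvature
            # with respect to the parameters (e.g. a purely linear loss).
            dot = sum((g * v).sum() for g, v in zip(grads, vs))
            hvps = torch.autograd.grad(dot, params)
        self.last_hvp = list(hvps)

        # Placeholder update: plain gradient descent. A real update rule
        # would use self.last_hvp here.
        with torch.no_grad():
            i = 0
            for group in self.param_groups:
                for p in group["params"]:
                    p.add_(vs[i], alpha=-group["lr"])
                    i += 1
        return loss.detach()

# Toy usage on a quadratic loss 0.5 * w^T A w (illustration only).
A = torch.tensor([[2.0, 0.0], [0.0, 3.0]])
w = torch.nn.Parameter(torch.tensor([1.0, 1.0]))
opt = HVPSketch([w], lr=0.1)
loss = opt.step(lambda: 0.5 * w @ A @ w)
```

With this pattern the training loop only passes a closure, as with L-BFGS, and neither the model outputs nor the labels need to be handed to `step` explicitly.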

Thank you for your time!
