So I have been playing around with implementing an optimizer in torch whose update rule leverages the Hessian-vector product. I am keen on reusing the existing VHP function for this, but its really awkward API makes it impossible to use inside the `step` method when defining a custom optimizer class in torch.
Here are the problems that I currently have with the existing VHP interface:
- One typically does not pass any arguments to the `step` method when calling `optimizer.step()`, except in the case of algorithms like `L-BFGS`, which require a `closure` argument since they re-evaluate the loss function multiple times every iteration. Given the VHP API, a typical call would look something like `VHP(loss_func, (model_output_tensor, gt_labels/values_tensor), v)` (I have gripes about this signature as well, which I explain in the next point). But `step` does not typically have access to any of these arguments! `step` has no idea about `loss_func`, `model_output_tensor`, or `gt_labels`, so just to get it to work I decided to dump all of these arguments inside of `closure` (a horrible choice, but that seemed to be the only way out), but then...
- It seems that VHP treats the `inputs` tuple as separate inputs to the function, to be evaluated independently (which is why it explicitly imposes `len(inputs) == len(v)`). But loss functions almost always take vector inputs, and I have no clue how one could pass two vector inputs in a single tensor (except maybe by vertically stacking them, though I am unsure whether VHP would parse the inputs that way). A minimal example of the signature in question is shown right after this list.
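For concreteness, here is a minimal, self-contained sketch of the call I am describing, using `torch.autograd.functional.vhp` with made-up names (`pred`, `target`):

```python
import torch
from torch.autograd.functional import vhp

# Toy loss taking two tensor inputs (all names here are made up)
def loss_func(pred, target):
    return ((pred - target) ** 2).mean()

pred = torch.randn(5)
target = torch.randn(5)

# v must mirror the structure of `inputs`: one tensor per input,
# which is where the len(inputs) == len(v) constraint comes from
v = (torch.ones(5), torch.zeros(5))

loss_val, (vhp_pred, vhp_target) = vhp(loss_func, (pred, target), v)
```

Note that the Hessian here is taken with respect to `pred` and `target`, i.e. the function's inputs, not the model parameters that an optimizer actually updates.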
TLDR: The API for torch's VHP method is really hard to work with when defining a custom optimizer in torch. Are there any interesting alternatives for computing the Hessian-vector product that do not run into the pitfalls I mentioned above? As far as I am aware, JAX also has an identical API for its own HVP method.
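For reference, the kind of alternative I have in mind is the classic double-backward trick, sketched below with respect to the parameters rather than the function inputs; `model`, `loss_fn`, `inputs`, and `targets` are stand-ins for whatever the closure would see:

```python
import torch

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

params = [p for p in model.parameters() if p.requires_grad]
v = [torch.randn_like(p) for p in params]  # vector to multiply the Hessian by

loss = loss_fn(model(inputs), targets)
# First backward pass: keep the graph so the gradients can be differentiated again
grads = torch.autograd.grad(loss, params, create_graph=True)
# Scalar g . v, then a second backward pass gives d(g . v)/d(theta) = Hv
gv = sum((g * u).sum() for g, u in zip(grads, v))
hvp = torch.autograd.grad(gv, params)
```

Since this only needs the loss itself, it would fit the `L-BFGS`-style `step(closure)` pattern without smuggling `loss_func` and the data tensors through `closure`; I am mainly wondering whether there is a cleaner or more efficient built-in route.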
Thank you for your time!