I find JAX's custom automatic differentiation capabilities (custom JVP and VJP rules) very useful, but I am having a hard time applying them to higher-order functions. Here is a minimal example. Given the higher-order function:
```python
def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func
```
I would like to define custom gradients of child_func with respect to x and y. What would be the correct syntax to achieve this?
Gradients in JAX are defined with respect to a function's explicit inputs. Your `child_func` does not take `x` as an explicit input, so you cannot directly differentiate `child_func` with respect to `x`. However, you can do so indirectly by calling it from another function that takes `x` as an argument, for example a function `func_to_differentiate(x, y)` that returns `parent_func(x)(y)`. Then, if you wish, you can define standard custom derivative rules for `func_to_differentiate`.
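Here is a minimal sketch of what those "standard custom derivative rules" could look like, using the real `jax.custom_jvp`/`defjvp` and `jax.custom_vjp`/`defvjp` decorators. It assumes the wrapper simply returns `parent_func(x)(y)` and hand-codes the derivatives of `x**2 * y` (namely `2*x*y` with respect to `x` and `x**2` with respect to `y`). The names `func_to_differentiate_jvp`, `func_to_differentiate_vjp`, `_fwd`, and `_bwd` are hypothetical, chosen only for illustration.

```python
import jax


def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func


# Option 1: custom JVP (forward-mode rule; JAX derives reverse mode from it).
@jax.custom_jvp
def func_to_differentiate(x, y):
    # x is now an explicit input, so it can carry a derivative rule.
    return parent_func(x)(y)


@func_to_differentiate.defjvp
def func_to_differentiate_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = func_to_differentiate(x, y)
    # d/dx (x**2 * y) = 2*x*y and d/dy (x**2 * y) = x**2; the output
    # tangent must be linear in (x_dot, y_dot) so jax.grad can transpose it.
    tangent_out = 2 * x * y * x_dot + x**2 * y_dot
    return primal_out, tangent_out


# Option 2: custom VJP (reverse-mode rule with explicit fwd/bwd passes).
@jax.custom_vjp
def func_to_differentiate_vjp(x, y):
    return parent_func(x)(y)


def _fwd(x, y):
    # Save the inputs as residuals for the backward pass.
    return parent_func(x)(y), (x, y)


def _bwd(residuals, g):
    x, y = residuals
    # One cotangent per primal input, in the order (x, y).
    return (g * 2 * x * y, g * x**2)


func_to_differentiate_vjp.defvjp(_fwd, _bwd)

# Usage: gradients with respect to both x and y.
grads_jvp = jax.grad(func_to_differentiate, argnums=(0, 1))(3.0, 2.0)
grads_vjp = jax.grad(func_to_differentiate_vjp, argnums=(0, 1))(3.0, 2.0)
```

With `x = 3.0` and `y = 2.0`, both versions should give `2*x*y = 12` for `x` and `x**2 = 9` for `y`. A JVP rule is usually the simpler choice because JAX can derive both forward and reverse mode from it; a VJP rule is useful when you specifically need control over what is saved for the backward pass.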