Custom JVP and VJP for higher order functions in JAX


I find JAX's custom automatic differentiation capabilities (custom JVP and VJP rules) very useful, but I am having a hard time applying them to higher-order functions. Here is a minimal example of such a higher-order function:

def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func

I would like to define custom gradients of child_func with respect to both x and y. What is the correct syntax to achieve this?

Accepted answer by jakevdp:

Gradients in JAX are defined with respect to a function’s explicit inputs. Your child_func does not take x as an explicit input, so you cannot directly differentiate child_func with respect to x. However, you could do so indirectly by calling it from another function that takes x. For example:

import jax

def func_to_differentiate(x, y):
  child_func = parent_func(x)  # parent_func as defined in the question
  return child_func(y)

jax.grad(func_to_differentiate, argnums=0)(1.0, 1.0)  # 2.0
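A minimal sketch of the same idea, reusing parent_func from the question: passing argnums=(0, 1) to jax.grad returns the gradients with respect to both x and y at once, while differentiating the returned closure directly only yields the derivative with respect to y, its sole explicit input. The evaluation point (3.0, 2.0) is arbitrary.

```python
import jax


def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func


def func_to_differentiate(x, y):
    child_func = parent_func(x)
    return child_func(y)


# argnums=(0, 1) returns a tuple: (d/dx, d/dy) = (2*x*y, x**2).
dx, dy = jax.grad(func_to_differentiate, argnums=(0, 1))(3.0, 2.0)
print(dx, dy)  # 12.0 9.0

# Differentiating the closure directly only gives d/dy, since y is its only explicit input.
dy_closure = jax.grad(parent_func(3.0))(2.0)
print(dy_closure)  # 9.0
```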

Then, if you wish, you can define standard custom derivative rules (with jax.custom_jvp or jax.custom_vjp) for func_to_differentiate.
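One possible sketch, assuming you also want the custom rule to apply when child_func is used on its own: move the body into a jax.custom_jvp function that takes x explicitly (child_impl is a hypothetical name chosen for illustration), and have parent_func bind x with functools.partial instead of capturing it in a closure. Because x is now an explicit argument of the custom-JVP function, derivatives with respect to both x and y go through the hand-written rule.

```python
import functools

import jax


# Hypothetical helper: the original child_func body, with x made an explicit argument.
@jax.custom_jvp
def child_impl(x, y):
    return x**2 * y


@child_impl.defjvp
def child_impl_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = child_impl(x, y)
    # Hand-written rule: d/dx (x**2 * y) = 2*x*y and d/dy (x**2 * y) = x**2.
    tangent_out = 2 * x * y * x_dot + x**2 * y_dot
    return primal_out, tangent_out


def parent_func(x):
    # Bind x as an explicit argument rather than a closed-over variable.
    return functools.partial(child_impl, x)


def func_to_differentiate(x, y):
    return parent_func(x)(y)


# Gradients with respect to both x and y, computed via child_impl_jvp.
dx, dy = jax.grad(func_to_differentiate, argnums=(0, 1))(3.0, 2.0)
print(dx, dy)  # 12.0 9.0

# The returned child_func also uses the custom rule (gradient w.r.t. y only).
child_func = parent_func(3.0)
print(jax.grad(child_func)(2.0))  # 9.0
```

A jax.custom_vjp version would follow the same pattern, with defvjp supplying forward and backward functions in place of the JVP rule.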