JAX `custom_vjp` for functions with multiple outputs

The JAX documentation covers custom derivatives for functions with a single output. How can I implement custom derivatives for functions with multiple outputs, such as this one?

# want to define custom derivative of out_2 with respect to *args
def test_func(*args, **kwargs):
    ...
    return out_1, out_2

Accepted answer by jakevdp:

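You can define custom derivatives for functions with any number of inputs and outputs. In a `custom_jvp` rule, the `primals` and `tangents` tuples hold one element per input, and the rule returns `primals_out` and `tangents_out` with one element per output. (A `custom_vjp` rule handles multiple outputs the same way: its backward function receives one cotangent per output.) For example: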

import jax
import jax.numpy as jnp

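@jax.custom_jvp
def f(x, y):
  # Two outputs: the product and the quotient of the inputs.
  return x * y, x / y

@f.defjvp
def f_jvp(primals, tangents):
  # One primal and one tangent per input.
  x, y = primals
  x_dot, y_dot = tangents
  primals_out = f(x, y)
  # One tangent per output, matching the structure of primals_out:
  # product rule for x * y, quotient rule for x / y.
  tangents_out = (x_dot * y + y_dot * x,
                  x_dot / y - y_dot * x / y ** 2)
  return primals_out, tangents_out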

x = jnp.float32(0.5)
y = jnp.float32(2.0)

# Outer tuple: one entry per output; inner tuple: one entry per argument in argnums.
jax.jacobian(f, argnums=(0, 1))(x, y)
# ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
#  (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))
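Because the rule returns one tangent per output, calling `jax.jvp` on `f` directly shows both output tangents at once. The sketch below repeats the definitions above so it runs on its own; the input tangent `(1, 0)` is an illustrative choice that picks out the derivatives with respect to `x`:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return x * y, x / y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primals_out = f(x, y)
  # One tangent per output.
  tangents_out = (x_dot * y + y_dot * x,
                  x_dot / y - y_dot * x / y ** 2)
  return primals_out, tangents_out

x = jnp.float32(0.5)
y = jnp.float32(2.0)

# Tangent dtypes must match the primal dtypes, hence jnp.float32.
# The tangent (1, 0) selects the derivatives with respect to x.
primals_out, tangents_out = jax.jvp(f, (x, y), (jnp.float32(1.0), jnp.float32(0.0)))
# primals_out  is a tuple: (x * y, x / y)        -> (1.0, 0.25)
# tangents_out is a tuple: (d(x*y)/dx, d(x/y)/dx) -> (2.0, 0.5)
```

Pushing the tangent `(0, 1)` instead would give the derivatives with respect to `y`, i.e. the second column of the Jacobian.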

Comparing this with the Jacobian that JAX computes from its standard (non-custom) differentiation rules for the same function confirms that the results match:

def f2(x, y):
  return x * y, x / y

jax.jacobian(f2, argnums=(0, 1))(x, y)
# ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
#  (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))
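Since the question title asks about `custom_vjp`, here is a sketch of the same function written with `jax.custom_vjp` instead. It is an illustration, not part of the original answer: the forward function `f_fwd` saves `(x, y)` as residuals, and the backward function `f_bwd` receives a tuple with one cotangent per output and returns one cotangent per input. `custom_vjp` functions support only reverse-mode differentiation, so the Jacobian is computed with `jax.jacrev` (which `jax.jacobian` is an alias for); `jax.jacfwd` or `jax.jvp` would raise an error here.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return x * y, x / y

def f_fwd(x, y):
  # Forward pass: the outputs plus residuals saved for the backward pass.
  return f(x, y), (x, y)

def f_bwd(res, g):
  # g holds one cotangent per output: g1 for x * y, g2 for x / y.
  x, y = res
  g1, g2 = g
  # Each input's cotangent sums the contributions from both outputs.
  x_bar = g1 * y + g2 / y
  y_bar = g1 * x - g2 * x / y ** 2
  # One cotangent per positional input, in order.
  return x_bar, y_bar

f.defvjp(f_fwd, f_bwd)

x = jnp.float32(0.5)
y = jnp.float32(2.0)

# Outer tuple: one entry per output; inner tuple: one entry per argument.
jac = jax.jacrev(f, argnums=(0, 1))(x, y)
# ((2.0, 0.5), (0.5, -0.125)), the same values as above
```

If only `out_2` needs a custom derivative, as in the question, the backward rule still has to account for every output, because `custom_vjp` replaces the whole differentiation rule: write the standard contribution for `out_1` (here `g1 * y` and `g1 * x`) alongside the custom one for `out_2`.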