The JAX documentation covers custom derivatives for functions with a single output. How can I implement custom derivatives for a function with multiple outputs, such as this one?
```python
# want to define custom derivative of out_2 with respect to *args
def test_func(*args, **kwargs):
    ...
    return out_1, out_2
```
You can define custom derivatives for functions with any number of inputs and outputs. Just add the appropriate number of elements to the `primals` and `tangents` tuples in the `custom_jvp` rule. To check the rule, compare its result with the one computed by the standard, non-custom derivative rule for the same function. The two should agree.
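Here is a minimal sketch of such a rule. It uses a simple two-input, two-output function `f` chosen only for illustration; it is not the function from the question.

- The JVP rule unpacks one primal and one tangent per input.
- It returns a tuple of primal outputs together with a matching tuple of tangent outputs.
- `f_ref` is the same function with no custom rule. It lets you check the gradient of the second output (the analogue of `out_2`) and the full Jacobian against JAX's built-in derivatives.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
    # Illustrative function with two inputs and two outputs.
    return jnp.sin(x) * y, x * jnp.exp(y)


@f.defjvp
def f_jvp(primals, tangents):
    # One primal and one tangent per positional input.
    x, y = primals
    x_dot, y_dot = tangents
    out_1, out_2 = f(x, y)
    # Hand-written tangents, one per output (linear in x_dot and y_dot).
    out_1_dot = jnp.cos(x) * y * x_dot + jnp.sin(x) * y_dot
    out_2_dot = jnp.exp(y) * x_dot + x * jnp.exp(y) * y_dot
    # Tuple of primal outputs, then a matching tuple of tangent outputs.
    return (out_1, out_2), (out_1_dot, out_2_dot)


def f_ref(x, y):
    # Same function without a custom rule, differentiated by JAX's built-in rules.
    return jnp.sin(x) * y, x * jnp.exp(y)


x, y = 0.5, 1.5

# Gradient of the second output only, with respect to both inputs.
grad_custom = jax.grad(lambda a, b: f(a, b)[1], argnums=(0, 1))(x, y)
grad_ref = jax.grad(lambda a, b: f_ref(a, b)[1], argnums=(0, 1))(x, y)
print(grad_custom)
print(grad_ref)

# Full Jacobian: one entry per (output, input) pair.
jac_custom = jax.jacfwd(f, argnums=(0, 1))(x, y)
jac_ref = jax.jacfwd(f_ref, argnums=(0, 1))(x, y)
```

Both `print` calls should show the same pair of values. The tangent outputs are linear in `x_dot` and `y_dot`, so JAX can transpose the rule. This is what lets reverse-mode `jax.grad` work with a `custom_jvp` definition, not just forward-mode `jax.jacfwd`.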