I have a JAX function cart_deriv() which takes another function f and returns the Cartesian derivative of f, implemented as follows:
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
from jax import Array, custom_vjp, jacrev, vjp

@partial(custom_vjp, nondiff_argnums=0)
def cart_deriv(f: Callable[..., float],
               l: int,
               R: Array
               ) -> Array:
    df = lambda R: f(l, jnp.dot(R, R))
    for i in range(l):
        df = jacrev(df)
    return df(R)

def cart_deriv_fwd(f, l, primal):
    primal_out = cart_deriv(f, l, primal)
    residual = cart_deriv(f, l+1, primal)  ## just a test
    return primal_out, residual

def cart_deriv_bwd(f, residual, cotangent):
    cotangent_out = jnp.ones(3)  ## just a test
    return (None, cotangent_out)

cart_deriv.defvjp(cart_deriv_fwd, cart_deriv_bwd)

if __name__ == "__main__":
    def test_func(l, r2):
        return l + r2

    primal_out, f_vjp = vjp(cart_deriv,
                            jax.tree_util.Partial(test_func),
                            2,
                            jnp.array([1., 2., 3.])
                            )
    cotangent = jnp.ones((3, 3))
    cotangent_out = f_vjp(cotangent)
    print(cotangent_out[1].shape)
However, this code produces the error:
TypeError: cart_deriv_bwd() missing 1 required positional argument: 'cotangent'
I have checked that the syntax agrees with that in the documentation. I'm wondering why the argument cotangent is not recognized by vjp, and how to fix this error?
The issue is that nondiff_argnums is expected to be a sequence, so it should be written as nondiff_argnums=(0,) rather than nondiff_argnums=0.
With this properly defined, it's better to avoid wrapping the function in Partial, and just pass it as a static argument by closing over it in the vjp call.