I would like to use the package Oryx to invert an affine transformation written in JAX. The transformation maps x->y and depends on a set of adjustable parameters (which I call params). Specifically, the affine transformation is defined as:
import jax.numpy as jnp

def affine(params, x):
    return x * params['scale'] + params['shift']

params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
y_out = affine(params, x_in)
I would like to invert affine with respect to the input x, as a function of params. Oryx has a function oryx.core.inverse to invert JAX functions. However, inverting a function with parameters, like this:
import oryx
oryx.core.inverse(affine)(params, y_out)
doesn't work (AssertionError: length mismatch: [1, 3]), presumably because inverse doesn't know that I want to invert y_out but not params.
What is the most elegant way to solve this problem for all possible values (i.e., as a function) of params using oryx.core.inverse?
I find the inverse docs not very illuminating.
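For reference, the inverse being asked for has a simple closed form, x = (y - shift) / scale, so any Oryx-based solution can be checked against it. A minimal sketch in plain Python; the name affine_inverse_reference is hypothetical and not part of Oryx:

```python
def affine(params, x):
    # Forward map x -> y, as defined in the question.
    return x * params['scale'] + params['shift']

def affine_inverse_reference(params, y):
    # Hand-written inverse y -> x (hypothetical helper, for checking only).
    # It is a function of params, which is the behaviour wanted from Oryx.
    return (y - params['shift']) / params['scale']

params = dict(scale=1.5, shift=-1.)
y_ref = affine(params, 3.0)                      # 3.5
x_ref = affine_inverse_reference(params, y_ref)  # 3.0
```

The same helper also accepts jnp arrays, since it only uses arithmetic operators.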
Update:
Jakevdp gave an excellent suggestion for a given set of params. I've clarified the question to indicate that I am wondering how to define the inverse as a function of params.
You can do this by closing over the static parameters, for example using partial:

Edit: if you want a single inverted function to work for multiple values of params, you will have to return params in the output (otherwise, there's no way from a single output value to infer all three inputs). It might look something like this: