I am trying to speed up some JAX code I have been writing by replacing two of my for loops with faster equivalents.
The first loop takes a function from a list, evaluates it on the corresponding slice of the input (3D) arrays xs and ys, and updates the input arrays before passing the next slice to the next function:
for j in range(N):
    dx, dy = mass_models[j].alpha(
        xs[j],
        ys[j],
        kwargs=kwargs[j]
    )
    weights = etas[j].reshape(-1, 1)  # make this a column vector
    xs = xs - weights * dx
    ys = ys - weights * dy
I tried using jax.lax.scan, because the input is updated on every iteration, but I couldn't figure out how to call a different function on each iteration. It is also tricky because I need to pass different kwargs to each function.
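One possible sketch for the first loop, assuming every alpha method returns arrays of the same shape and dtype (jax.lax.switch requires this of its branches): jax.lax.scan carries (xs, ys) from step to step, and jax.lax.switch picks the j-th model out of a list of branches. Each branch closes over its own model and kwargs dict, so the dicts do not need matching keys. ScaleModel, SineModel, the toy shapes and the etas values are hypothetical stand-ins, and the weights are reshaped to broadcast over all trailing axes of xs instead of with reshape(-1, 1).

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins for the real mass models: each has an
# alpha(x, y, kwargs=...) method and may expect different kwargs keys.
class ScaleModel:
    def alpha(self, x, y, kwargs):
        return kwargs["a"] * x, kwargs["a"] * y


class SineModel:
    def alpha(self, x, y, kwargs):
        return kwargs["b"] * jnp.sin(x), kwargs["c"] * jnp.cos(y)


def deflect_loop(mass_models, kwargs, etas, xs, ys):
    """Reference version: the original Python for loop."""
    for j in range(len(mass_models)):
        dx, dy = mass_models[j].alpha(xs[j], ys[j], kwargs=kwargs[j])
        # Weights as a column, broadcast over the trailing axes of xs.
        weights = etas[j].reshape((-1,) + (1,) * (xs.ndim - 1))
        xs = xs - weights * dx
        ys = ys - weights * dy
    return xs, ys


def deflect_scan(mass_models, kwargs, etas, xs, ys):
    """Same computation: lax.scan carries (xs, ys), lax.switch selects model j."""
    # One branch per model; default arguments bind each model and its own
    # kwargs dict, so the dicts never have to be stacked or share keys.
    branches = [
        (lambda x, y, m=m, k=k: m.alpha(x, y, kwargs=k))
        for m, k in zip(mass_models, kwargs)
    ]

    def step(carry, inputs):
        xs, ys = carry
        j, eta = inputs
        # j is a traced integer here, so xs[j] becomes a dynamic slice.
        dx, dy = jax.lax.switch(j, branches, xs[j], ys[j])
        weights = eta.reshape((-1,) + (1,) * (xs.ndim - 1))
        return (xs - weights * dx, ys - weights * dy), None

    n = len(mass_models)
    (xs, ys), _ = jax.lax.scan(step, (xs, ys), (jnp.arange(n), etas))
    return xs, ys


# Toy usage with hypothetical shapes (N slices of P points each).
mass_models = [ScaleModel(), SineModel(), ScaleModel()]
kwargs = [{"a": 0.5}, {"b": 0.2, "c": 0.3}, {"a": 0.1}]
N, P = 3, 4
xs0 = jnp.linspace(0.0, 1.0, N * P).reshape(N, P)
ys0 = jnp.linspace(1.0, 2.0, N * P).reshape(N, P)
etas = jnp.triu(jnp.ones((N, N)), k=1)  # hypothetical weights

xs_loop, ys_loop = deflect_loop(mass_models, kwargs, etas, xs0, ys0)
# kwargs (a list of dicts) is a pytree, so it can be a traced jit argument.
run = jax.jit(lambda kw, x, y: deflect_scan(mass_models, kw, etas, x, y))
xs_scan, ys_scan = run(kwargs, xs0, ys0)
```

Under jax.jit the original Python loop is unrolled into N copies of its body, so the scan version mainly cuts tracing and compile time; whether it also runs faster will depend on N and on the models themselves.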
The next loop might be a bit easier: each slice along the first axis of xs and ys from above is passed to a different function, and the results are stacked (again with different kwargs for each function). At the moment I am using a list comprehension:
results = jnp.stack([
    light_models[j].surface_brightness(
        xs[j],
        ys[j],
        kwargs[j],
        pixels_x_coord=pixels_x_coord[j],
        pixels_y_coord=pixels_y_coord[j]
    ) for j in range(N)
])
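For the second loop, a possible sketch under two assumptions: every surface_brightness call returns an array of the same shape and dtype, and the entries of pixels_x_coord and pixels_y_coord all have the same shape so they can be stacked with jnp.stack. jax.lax.map then walks the first axis, and jax.lax.switch chooses the j-th light model. GaussLight, PixelShiftLight and all the toy data are hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical light models with the call signature used in the question.
class GaussLight:
    def surface_brightness(self, x, y, kwargs, pixels_x_coord=None, pixels_y_coord=None):
        return kwargs["amp"] * jnp.exp(-(x**2 + y**2) / kwargs["sigma"] ** 2)


class PixelShiftLight:
    def surface_brightness(self, x, y, kwargs, pixels_x_coord=None, pixels_y_coord=None):
        # Toy use of the pixel grids, only so that they are threaded through.
        return kwargs["amp"] * (x - pixels_x_coord) * (y - pixels_y_coord)


def brightness_list(light_models, kwargs, xs, ys, px, py):
    """Reference version: the original list comprehension."""
    return jnp.stack([
        light_models[j].surface_brightness(
            xs[j], ys[j], kwargs[j], pixels_x_coord=px[j], pixels_y_coord=py[j])
        for j in range(len(light_models))
    ])


def brightness_map(light_models, kwargs, xs, ys, px, py):
    """lax.map over the first axis, lax.switch to pick light model j."""
    # Each branch binds its own model and kwargs dict via default arguments.
    branches = [
        (lambda x, y, pxj, pyj, m=m, k=k: m.surface_brightness(
            x, y, k, pixels_x_coord=pxj, pixels_y_coord=pyj))
        for m, k in zip(light_models, kwargs)
    ]

    def one_slice(inputs):
        j, x, y, pxj, pyj = inputs
        return jax.lax.switch(j, branches, x, y, pxj, pyj)

    n = len(light_models)
    # Assumes every px[j] (and py[j]) has the same shape, so they can be stacked.
    return jax.lax.map(one_slice, (jnp.arange(n), xs, ys, jnp.stack(px), jnp.stack(py)))


# Toy usage with hypothetical data.
light_models = [GaussLight(), PixelShiftLight(), GaussLight()]
kwargs = [{"amp": 1.0, "sigma": 0.5}, {"amp": 2.0}, {"amp": 0.5, "sigma": 2.0}]
N, P = 3, 4
xs = jnp.linspace(-1.0, 1.0, N * P).reshape(N, P)
ys = jnp.linspace(0.0, 1.0, N * P).reshape(N, P)
pixels_x_coord = [jnp.linspace(0.0, 0.3, P) for _ in range(N)]
pixels_y_coord = [jnp.linspace(0.1, 0.4, P) for _ in range(N)]

out_list = brightness_list(light_models, kwargs, xs, ys, pixels_x_coord, pixels_y_coord)
out_map = brightness_map(light_models, kwargs, xs, ys, pixels_x_coord, pixels_y_coord)
```

lax.map is built on lax.scan, so the slices are still processed one after another; the gain is a smaller traced program rather than parallelism. Swapping in jax.vmap would batch the calls, but a batched lax.switch index is turned into evaluating every branch on every slice.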
Is there a more JAX-friendly way to write these kinds of loops?
Edit:
For context, xs and ys are JAX arrays, kwargs is a list of dictionaries (these are the parameters being estimated by the fitting code that uses this), and pixels_*_coord are lists. mass_models and light_models are lists of class instances.
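If, in a given fit, all the light models happen to be instances of the same class and their kwargs dicts share the same keys, the list of dictionaries is a pytree that can be stacked into a single dict of arrays and mapped with jax.vmap, which removes the Python loop entirely. A minimal sketch of that special case, where GaussLight and the toy data are hypothetical:

```python
import jax
import jax.numpy as jnp


class GaussLight:  # hypothetical stand-in for one of the real light model classes
    def surface_brightness(self, x, y, kwargs, pixels_x_coord=None, pixels_y_coord=None):
        return kwargs["amp"] * jnp.exp(-(x**2 + y**2) / kwargs["sigma"] ** 2)


def stack_kwargs(kwargs):
    """Turn a list of dicts with identical keys into one dict of stacked arrays."""
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack([jnp.asarray(leaf) for leaf in leaves]), *kwargs)


model = GaussLight()
kwargs = [{"amp": 1.0, "sigma": 0.5}, {"amp": 2.0, "sigma": 1.0}, {"amp": 0.5, "sigma": 2.0}]
N, P = 3, 4
xs = jnp.linspace(-1.0, 1.0, N * P).reshape(N, P)
ys = jnp.linspace(0.0, 1.0, N * P).reshape(N, P)

stacked = stack_kwargs(kwargs)  # {"amp": array of shape (3,), "sigma": array of shape (3,)}
# vmap maps over the leading axis of xs, ys and every leaf of the stacked dict.
batched = jax.vmap(lambda x, y, k: model.surface_brightness(x, y, k))(xs, ys, stacked)
looped = jnp.stack([model.surface_brightness(xs[j], ys[j], kwargs[j]) for j in range(N)])
```

This only works when every dict has the same structure, since tree_map requires all the trees it combines to match; with mixed model classes or differing keys, some per-model dispatch is still needed.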