Mapping over multiple functions in JAX to speed up for loops


I am trying to speed up some JAX code I have written, and in particular to replace two of my for loops with faster equivalents.

The first loop takes the j-th function from a list, evaluates it on the j-th slice of the (3D) input arrays xs and ys, and updates the input arrays before the next slice is passed to the next function:

for j in range(N):
    dx, dy = mass_models[j].alpha(
        xs[j],
        ys[j],
        kwargs=kwargs[j]
    )
    weights = etas[j].reshape(-1, 1) # make this a column vector
    xs = xs - weights * dx
    ys = ys - weights * dy

I tried using jax.lax.scan, because the inputs are updated on every iteration, but I couldn't figure out how to make each iteration call a different function. It is also tricky because each function needs its own kwargs.
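One pattern that may fit here (a sketch, not tested against the real classes) is to keep jax.lax.scan for the sequential update and call jax.lax.switch inside the scan body to pick the j-th function. Each branch closes over its own model and kwargs dict, so the dictionaries do not need to share keys. Assumptions: `ToyPointMass`, `ToyShear`, `ray_shoot_loop` and `ray_shoot_scan` are hypothetical names; `etas` is a single JAX array of shape (N, N) rather than a list (a traced index cannot index a Python list); xs and ys have shape (N, M) so that `etas[j].reshape(-1, 1)` broadcasts against dx; and every `alpha` returns arrays of the same shape and dtype, which lax.switch requires.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy mass models standing in for the real mass_models classes.
class ToyPointMass:
    def alpha(self, x, y, kwargs):
        r2 = x**2 + y**2 + 1e-12  # tiny offset avoids division by zero at r = 0
        scale = kwargs["theta_E"] ** 2 / r2
        return scale * x, scale * y


class ToyShear:
    def alpha(self, x, y, kwargs):
        g1, g2 = kwargs["gamma1"], kwargs["gamma2"]
        return g1 * x + g2 * y, g2 * x - g1 * y


def ray_shoot_loop(xs, ys, etas, mass_models, kwargs):
    """Reference version: the original Python loop."""
    for j in range(len(mass_models)):
        dx, dy = mass_models[j].alpha(xs[j], ys[j], kwargs=kwargs[j])
        weights = etas[j].reshape(-1, 1)  # column vector
        xs = xs - weights * dx
        ys = ys - weights * dy
    return xs, ys


def ray_shoot_scan(xs, ys, etas, mass_models, kwargs):
    """Same update with lax.scan; lax.switch picks the j-th model."""
    # One branch per model. Default arguments bind each model and its kwargs
    # dict when the lambda is created, so the dicts need not share keys.
    branches = [
        lambda x, y, m=m, kw=kw: m.alpha(x, y, kwargs=kw)
        for m, kw in zip(mass_models, kwargs)
    ]

    def body(carry, j):
        xs, ys = carry
        dx, dy = jax.lax.switch(j, branches, xs[j], ys[j])
        weights = etas[j].reshape(-1, 1)  # etas must be an array, not a list
        return (xs - weights * dx, ys - weights * dy), None

    (xs, ys), _ = jax.lax.scan(body, (xs, ys), jnp.arange(len(mass_models)))
    return xs, ys
```

Because this loop is inherently sequential, the main gain is likely in tracing and compile time (the body is compiled once instead of being unrolled N times); per-iteration runtime may not change much, and for a small N a jitted Python loop may be just as fast. Gradients with respect to the kwargs flow through both scan and switch.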

The next loop might be a bit easier: each slice along the first axis of xs and ys (from above) is passed to a different function, and the results are stacked (again with different kwargs for each function). At the moment I am using a list comprehension:

results = jnp.stack([
    light_models[j].surface_brightness(
        xs[j],
        ys[j],
        kwargs[j],
        pixels_x_coord=pixels_x_coord[j],
        pixels_y_coord=pixels_y_coord[j]
    ) for j in range(N)
])
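For this second loop, one option (again a sketch under assumptions) is jax.lax.map, which applies a function over the leading axis of its input and stacks the results, combined with lax.switch to select the j-th light model. The per-slice kwargs and pixel coordinates are captured in each branch's closure. `ToyGaussian`, `ToyUniform`, `brightness_stack` and `brightness_map` are hypothetical names; the toy models accept the pixel-coordinate arguments only to mirror the original call and ignore them. Every branch must return the same shape and dtype, which jnp.stack already requires.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy light models; pixels_x_coord / pixels_y_coord are accepted
# only to mirror the original call signature and are ignored here.
class ToyGaussian:
    def surface_brightness(self, x, y, kwargs, pixels_x_coord=None, pixels_y_coord=None):
        r2 = x**2 + y**2
        return kwargs["amp"] * jnp.exp(-r2 / (2.0 * kwargs["sigma"] ** 2))


class ToyUniform:
    def surface_brightness(self, x, y, kwargs, pixels_x_coord=None, pixels_y_coord=None):
        return kwargs["amp"] * jnp.ones_like(x)


def brightness_stack(xs, ys, light_models, kwargs, pixels_x_coord, pixels_y_coord):
    """Reference version: the original list comprehension."""
    return jnp.stack([
        light_models[j].surface_brightness(
            xs[j], ys[j], kwargs[j],
            pixels_x_coord=pixels_x_coord[j],
            pixels_y_coord=pixels_y_coord[j],
        )
        for j in range(len(light_models))
    ])


def brightness_map(xs, ys, light_models, kwargs, pixels_x_coord, pixels_y_coord):
    """Same result with lax.map + lax.switch (one branch runs per slice)."""
    # Each branch binds its own model, kwargs and pixel coordinates.
    branches = [
        (lambda x, y, m=m, kw=kw, px=px, py=py:
            m.surface_brightness(x, y, kw, pixels_x_coord=px, pixels_y_coord=py))
        for m, kw, px, py in zip(light_models, kwargs, pixels_x_coord, pixels_y_coord)
    ]

    def one_slice(j):
        return jax.lax.switch(j, branches, xs[j], ys[j])

    # lax.map iterates over indices 0..N-1 and stacks the per-slice outputs.
    return jax.lax.map(one_slice, jnp.arange(len(light_models)))
```

It is worth avoiding jax.vmap in place of lax.map here: when the switch index is batched, lax.switch is converted into a select, so every light model would be evaluated on every slice.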

Is there a more JAX friendly way to write these kinds of loops?

Edit: For context, xs and ys are JAX arrays; kwargs is a list of dictionaries (these are the parameters estimated by the fitting code that calls this); pixels_x_coord and pixels_y_coord are lists; and mass_models and light_models are lists of class instances.
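Since kwargs is a list of dictionaries, there may be a further option for the second (independent) loop when several slices use the same model class with the same keys: stack those dictionaries into one dictionary of arrays with jax.tree_util.tree_map and evaluate them with a single jax.vmap, which processes all those slices in one vectorized call. This does not apply to the first loop, whose slices depend on each other. A sketch under assumptions: `ToyGaussian`, `stack_kwargs` and `sb_same_class` are hypothetical names, and the pixel-coordinate arguments are left out; if they are arrays of equal shape, they could presumably be stacked and passed through vmap the same way.

```python
import jax
import jax.numpy as jnp


class ToyGaussian:
    """Hypothetical stand-in for one light-model class."""

    def surface_brightness(self, x, y, kwargs):
        r2 = x**2 + y**2
        return kwargs["amp"] * jnp.exp(-r2 / (2.0 * kwargs["sigma"] ** 2))


def stack_kwargs(kwargs_list):
    """Turn a list of dicts with identical keys into a dict of stacked arrays."""
    return jax.tree_util.tree_map(
        lambda *vals: jnp.stack([jnp.asarray(v) for v in vals]), *kwargs_list
    )


def sb_same_class(model, xs, ys, kwargs_list):
    """Evaluate one model class on all of its slices with a single vmap."""
    stacked = stack_kwargs(kwargs_list)
    # vmap maps over axis 0 of xs, ys and of every leaf in the stacked dict.
    return jax.vmap(model.surface_brightness)(xs, ys, stacked)
```

In practice this would mean grouping the slice indices by model class, calling something like `sb_same_class` once per group, and scattering the results back into their original order.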
