I am working on a differential equation solver written in JAX. A common workflow I come across is something like this:
```python
from functools import partial

import jax.numpy as jnp
from jax import jit

# Function to integrate.
@jit
def dxdt(t, x):
    return -x**2

# Euler method for simplicity.
# `f` is a Python function, not an array, so jit must treat it as static.
@partial(jit, static_argnums=0)
def integrator(f, t, x, dt):
    return x + f(t, x) * dt

t_arr = jnp.linspace(0, 10, 100)
dt = t_arr[1] - t_arr[0]
x_list = []
# initialize x.
x = 0.
for t in t_arr:
    x_list.append(x)
    x = integrator(dxdt, t, x, dt)
x_arr = jnp.array(x_list)
```
My question is whether there is a way to 'vectorize' that for-loop using JAX.
I recognize that `jax.vmap()` would not be appropriate here, since the variable `x` is updated in each iteration of the loop. Is there a more JAX-friendly approach to this workflow?
This sort of sequential operation, where each step depends on the previous one, is supported in JAX via
`jax.lax.scan`. Here's how you might do the equivalent of your computation using `scan`:
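A minimal sketch of that idea, assuming the explicit Euler update from the question. The helper names `euler_step` and `solve` are illustrative, not part of JAX, and the sketch uses an initial value `x0 = 1.0` instead of `0.` so the trajectory is non-trivial (with `x = 0.` every step of `-x**2` stays at zero). `jax.lax.scan(step, init, xs)` calls `step(carry, x)` once per element of `xs`, threads the carry through, and stacks the per-step outputs:

```python
from functools import partial

import jax
import jax.numpy as jnp


def dxdt(t, x):
    """Right-hand side of the ODE dx/dt = -x**2."""
    return -x**2


def euler_step(f, t, x, dt):
    """One explicit Euler step (hypothetical helper, same update as `integrator`)."""
    return x + f(t, x) * dt


@partial(jax.jit, static_argnums=0)  # `f` is a Python callable, so mark it static
def solve(f, x0, t_arr):
    dt = t_arr[1] - t_arr[0]
    # Give the initial carry the same dtype as the time grid so the carry type
    # is identical on every iteration, which scan requires.
    x0 = jnp.asarray(x0, dtype=t_arr.dtype)

    def step(x, t):
        # carry: current state x; per-iteration input: time t.
        # Output the state *before* the update, mirroring x_list.append(x).
        return euler_step(f, t, x, dt), x

    _, xs = jax.lax.scan(step, x0, t_arr)  # xs is stacked along axis 0
    return xs


t_arr = jnp.linspace(0, 10, 100)
x_arr = solve(dxdt, 1.0, t_arr)
print(x_arr.shape)  # (100,)
```

Because the whole loop is expressed as a single `scan`, `jit` traces `step` once and compiles one loop, rather than dispatching a separate call per Python iteration. The inner functions need no `@jit` of their own, since everything inside `solve` is already compiled.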