What is the most efficient implementation of a scalable autonomous tridiagonal system using JAX?
import functools as ft
import jax as jx
import jax.numpy as jnp
import jax.random as jrn
import jax.lax as jlx
def make_T(m):
    """Build a pseudo-random tridiagonal Jacobian stored as a (3, m) band.

    Row 0 holds the sub-diagonal (entry 0 left at zero), row 1 the main
    diagonal and row 2 the super-diagonal (last entry left at zero). Each
    diagonal is drawn from a normal distribution with its own fixed PRNG key,
    so the result is reproducible for a given m.
    """
    sub_diag = jrn.normal(jrn.PRNGKey(0), shape=(m - 1,))
    main_diag = jrn.normal(jrn.PRNGKey(1), shape=(m,))
    super_diag = jrn.normal(jrn.PRNGKey(2), shape=(m - 1,))

    band = jnp.zeros((3, m), dtype='f8')
    band = band.at[0, 1:].set(sub_diag)
    band = band.at[1, :].set(main_diag)
    band = band.at[2, :-1].set(super_diag)
    return band
def make_y(m):
    """Return a reproducible pseudo-random state vector of length m."""
    key = jrn.PRNGKey(3)
    return jrn.normal(key, shape=(m,))
def calc_f_base(y, T):
    """Evaluate the rate f = J @ y for a tridiagonal J stored as a band.

    Args:
        y: 1-D state array.
        T: (3, m) band; row 0 is the sub-diagonal, row 1 the main diagonal
            and row 2 the super-diagonal.

    Returns:
        1-D array with the same length as y.
    """
    sub_diag, main_diag, super_diag = T[0, 1:], T[1, :], T[2, :-1]

    # Diagonal contribution first, then each off-diagonal is accumulated onto
    # the rows it touches.
    rate = main_diag * y
    rate = rate.at[1:].add(sub_diag * y[:-1])
    rate = rate.at[:-1].add(super_diag * y[1:])
    return rate
# Problem size (2**22 unknowns); potentially exhausts resources
m = 1 << 22
# Reference tridiagonal band and the state at which the rate is evaluated
T, y = make_T(m), make_y(m)
# Bind the band so the rate becomes a function of the state alone
calc_f = ft.partial(calc_f_base, T=T)
Using jax.jacrev or jax.jacfwd generates the full Jacobian, which limits the size of the system that can be handled.
One attempt to overcome this limitation is as follows:
@ft.partial(jx.jit, static_argnums=(0,))
def calc_jacfwd_trid(calc_f, y):
    """Determine the tridiagonal band of the Jacobian of calc_f at y.

    One forward-mode JVP is taken per state variable using a one-hot tangent,
    so each JVP yields a single Jacobian column; only the (up to) three band
    entries of that column are stored.

    Args:
        calc_f: Rate function mapping a 1-D state to a 1-D rate of the same
            length. It is a static jit argument and therefore must be
            hashable.
        y: 1-D state array at which the Jacobian is evaluated.

    Returns:
        Array of shape (3, y.size) with dtype y.dtype: row 0 is the
        sub-diagonal (entry 0 is zero), row 1 the main diagonal and row 2 the
        super-diagonal (last entry is zero).
    """
    def scan_body(carry, i):
        t, T = carry
        # Move the one-hot tangent from i-1 to i BEFORE the JVP; clearing the
        # previous entry afterwards would differentiate along e[i-1] + e[i]
        # and store sums of two Jacobian columns.
        t = t.at[i - 1].set(0.0)
        t = t.at[i].set(1.0)
        _, dfy = jx.jvp(calc_f, (y,), (t,))
        # Column i contributes to the super-diagonal of row i-1, the diagonal
        # of row i and the sub-diagonal of row i+1.
        T = T.at[2, i - 1].set(dfy[i - 1])
        T = T.at[1, i].set(dfy[i])
        T = T.at[0, i + 1].set(dfy[i + 1])
        return (t, T), None

    # Initialise
    m = y.size
    T = jnp.zeros((3, m), dtype=y.dtype)
    if m < 2:
        # No off-diagonals exist; an all-ones tangent returns the (at most
        # one) diagonal entry directly.
        _, dfy = jx.jvp(calc_f, (y,), (jnp.ones_like(y),))
        return T.at[1, :].set(dfy)
    t = jnp.zeros_like(y)

    # Differentiate wrt y[0]: diagonal of row 0 and sub-diagonal of row 1
    t = t.at[0].set(1.0)
    _, dfy = jx.jvp(calc_f, (y,), (t,))
    idxs = jnp.array([1, 0]), jnp.array([0, 1])
    T = T.at[idxs].set(dfy[0:2])

    # Differentiate wrt y[1:-1]
    (t, T), _ = jlx.scan(scan_body, (t, T), jnp.arange(1, m - 1))

    # Differentiate wrt y[-1]: super-diagonal of row m-2, diagonal of row m-1
    t = t.at[m - 2:].set(jnp.array([0.0, 1.0]))
    _, dfy = jx.jvp(calc_f, (y,), (t,))
    idxs = jnp.array([2, 1]), jnp.array([m - 2, m - 1])
    T = T.at[idxs].set(dfy[-2:])
    return T
which permits
# Recover the tridiagonal band of the Jacobian of calc_f at y
T = calc_jacfwd_trid(calc_f, y)
# Pseudo-random right-hand side with the same shape as the state
df = jrn.normal(jrn.PRNGKey(4), shape=y.shape)
# The three band rows are unpacked as the solver's diagonals; the right-hand
# side is passed as a single-column matrix and flattened back afterwards
dx = jlx.linalg.tridiagonal_solve(*T, df[:, None]).flatten()
Is there a better approach and/or can the time complexity of calc_jacfwd_trid be reduced further?
EDIT: The following implementation is more compact, but its run times are slightly slower.
@ft.partial(jx.jit, static_argnums=(0,))
def calc_jacfwd_trid_map(calc_f, y):
    """Determine the tridiagonal band of the Jacobian of calc_f at y via lax.map.

    Each mapped step takes one forward-mode JVP with a one-hot tangent (one
    Jacobian column) and keeps a length-3 window around the diagonal.

    Args:
        calc_f: Rate function mapping a 1-D state to a 1-D rate of the same
            length. It is a static jit argument and therefore must be
            hashable.
        y: 1-D state array. A length-3 window is sliced from every column, so
            y.size must be at least 3.

    Returns:
        Array of shape (3, y.size): row 0 is the sub-diagonal (entry 0 comes
        from below the last row, i.e. zero for a tridiagonal Jacobian), row 1
        the main diagonal and row 2 the super-diagonal.
    """
    m = y.size
    zero_tangent = jnp.zeros_like(y)

    def map_body(i):
        # The one-hot tangent must be in place BEFORE the JVP; setting it
        # afterwards differentiates along a zero vector and yields all zeros.
        t = zero_tangent.at[i].set(1.0)
        _, dfy = jx.jvp(calc_f, (y,), (t,))
        # Window of column i ordered [row i-1, row i, row i+1]
        im1 = jnp.where(i > 0, i - 1, 0)
        Ti = jlx.dynamic_slice(dfy, (im1,), (3,))
        # First column: the window starts at row 0, so shift it down by one.
        # The entry wrapped into slot 0 is two rows below the diagonal and is
        # zero only for a genuinely tridiagonal Jacobian.
        Ti = jnp.where(i > 0, Ti, jnp.roll(Ti, shift=+1))
        # Last column: dynamic_slice clamps the start so the window ends at
        # row m-1; shift it up by one to restore the ordering.
        Ti = jnp.where(i < m - 1, Ti, jnp.roll(Ti, shift=-1))
        return Ti

    # Differentiate wrt y[:]; result has shape (m, 3)
    T = jlx.map(map_body, jnp.arange(m))

    # Correct the orientation of T: columns -> band rows, sub-diagonal first,
    # then re-index the off-diagonals by row instead of by column.
    T = T.transpose()
    T = jnp.flip(T, axis=0)
    T = T.at[0, :].set(jnp.roll(T[0, :], shift=+1))
    T = T.at[2, :].set(jnp.roll(T[2, :], shift=-1))
    return T