JAX Tridiagonal Jacobians

240 Views Asked by At

What is the most efficient implementation of a scalable autonomous tridiagonal system using JAX?

import functools as ft
import jax as jx
import jax.numpy as jnp
import jax.random as jrn
import jax.lax as jlx

def make_T(m):
    """Build a pseudo-random tridiagonal Jacobian stored as a (3, m) band.

    Row 0 carries m-1 random values in columns 1..m-1, row 1 carries m random
    values, and row 2 carries m-1 random values in columns 0..m-2. The two
    unused corner entries, T[0, 0] and T[2, m-1], stay zero. Every diagonal
    is drawn from its own fixed PRNG key, so repeated calls with the same m
    give the same band.

    NOTE(review): the band is requested with dtype 'f8'; JAX presumably only
    honours float64 when its x64 mode is enabled -- confirm.
    """
    n_off = m - 1
    sub = jrn.normal(jrn.PRNGKey(0), shape=(n_off,))
    main = jrn.normal(jrn.PRNGKey(1), shape=(m,))
    sup = jrn.normal(jrn.PRNGKey(2), shape=(n_off,))

    band = jnp.zeros((3, m), dtype='f8')
    band = band.at[0, 1:].set(sub)
    band = band.at[1, :].set(main)
    band = band.at[2, :-1].set(sup)
    return band

def make_y(m):
    """Return a pseudo-random state vector of length m.

    The values are drawn from a fixed PRNG key (3), so the same m always
    yields the same vector.
    """
    key = jrn.PRNGKey(3)
    return jrn.normal(key, shape=(m,))

def calc_f_base(y, T):
    """Evaluate the rate f for state y from the banded matrix T.

    The result is the tridiagonal matrix-vector product
        f[i] = T[0, i] * y[i-1] + T[1, i] * y[i] + T[2, i] * y[i+1],
    where the first term is omitted for i = 0 and the last term is omitted
    for the final index.
    """
    sub, main, sup = T[0], T[1], T[2]

    # Diagonal contribution
    rate = main * y
    # Sub-diagonal: rows 1.. pick up the preceding state entry
    rate = rate.at[1:].add(sub[1:] * y[:-1])
    # Super-diagonal: rows ..m-2 pick up the following state entry
    rate = rate.at[:-1].add(sup[:-1] * y[1:])
    return rate

# Problem size: 2**22 unknowns (potentially exhausts resources)
m = 1 << 22
T, y = make_T(m), make_y(m)

# Rate as a function of the state only, with the band T bound by keyword
calc_f = ft.partial(calc_f_base, T=T)

Using jax.jacrev or jax.jacfwd generates the full (dense) m × m Jacobian, which limits the size of the system that can be handled.

One attempt to overcome this limitation is as follows

@ft.partial(jx.jit, static_argnums=(0,))
def calc_jacfwd_trid(calc_f, y):
    """Return the tridiagonal band of the Jacobian of calc_f at y (forward mode).

    One JVP is evaluated per component of y, each with a unit tangent e_i, so
    that the JVP output is column i of the Jacobian. Only the entries in rows
    i-1, i and i+1 of that column are stored. This means m JVP evaluations
    for a state of size m.

    Args:
        calc_f: Function of the state whose Jacobian is wanted. It is passed
            to jit as a static argument (static_argnums=(0,)).
        y: State at which to differentiate; assumed to be 1-D. The boundary
            handling writes two-element slices, so y.size >= 2 is assumed.

    Returns:
        A (3, m) array T, m = y.size, with
            T[0, i] = d f[i] / d y[i-1]   (sub-diagonal,   T[0, 0]   = 0),
            T[1, i] = d f[i] / d y[i]     (main diagonal),
            T[2, i] = d f[i] / d y[i+1]   (super-diagonal, T[2, m-1] = 0).
    """

    def scan_body(carry, i):
        t, T = carry

        # Move the unit tangent from i-1 to i *before* the JVP. Clearing the
        # old entry only afterwards would differentiate along e_{i-1} + e_i
        # and mix two Jacobian columns into the stored values.
        t = t.at[i-1].set(0.0)
        t = t.at[i  ].set(1.0)

        _, dfy = jx.jvp(calc_f, (y,), (t,))

        # Column i of the Jacobian: rows i-1, i, i+1
        T = T.at[2,i-1].set(dfy[i-1])
        T = T.at[1,i  ].set(dfy[i  ])
        T = T.at[0,i+1].set(dfy[i+1])

        return (t, T), None

    # Initialise
    m = y.size
    t = jnp.zeros_like(y)
    T = jnp.zeros((3,m), dtype=y.dtype)

    # Differentiate wrt y[0]: column 0 only has rows 0 and 1
    t = t.at[0].set(1.0)
    _, dfy = jx.jvp(calc_f, (y,), (t,))
    idxs = jnp.array([1,0]), jnp.array([0,1])
    T = T.at[idxs].set(dfy[0:2])

    # Differentiate wrt y[1:-1]; t enters as e_0 and leaves as e_{m-2}
    (t, T), _ = jlx.scan(scan_body, (t,T), jnp.arange(1,m-1))

    # Differentiate wrt y[-1]: column m-1 only has rows m-2 and m-1
    t = t.at[m-2:].set(jnp.array([0.0,1.0]))
    _, dfy = jx.jvp(calc_f, (y,), (t,))
    idxs = jnp.array([2,1]), jnp.array([m-2,m-1])
    T = T.at[idxs].set(dfy[-2:])

    return T

which permits

# Recover the tridiagonal band of the Jacobian at y
# (the function defined above is calc_jacfwd_trid; jacfwd_trid is undefined)
T = calc_jacfwd_trid(calc_f, y)

# Solve the tridiagonal system for a pseudo-random right-hand side df;
# the rows of T are unpacked as the lower, main and upper diagonals
df = jrn.normal(jrn.PRNGKey(4), shape=y.shape)
dx = jlx.linalg.tridiagonal_solve(*T,df[:,None]).flatten()

Is there a better approach and/or can the time complexity of calc_jacfwd_trid be reduced further?

EDIT: The following implementation is more compact, but its run times are slightly slower.

@ft.partial(jx.jit, static_argnums=(0,))
def calc_jacfwd_trid_map(calc_f, y):
    """Return the tridiagonal band of the Jacobian of calc_f at y using lax.map.

    For every index i a JVP with the unit tangent e_i gives column i of the
    Jacobian, from which the three entries around the diagonal are kept.
    This means m JVP evaluations for a state of size m.

    Args:
        calc_f: Function of the state whose Jacobian is wanted. It is passed
            to jit as a static argument (static_argnums=(0,)).
        y: State at which to differentiate; assumed to be 1-D. A window of
            three entries is sliced from each column, so y.size >= 3 is
            assumed.

    Returns:
        A (3, m) array T, m = y.size, with
            T[0, i] = d f[i] / d y[i-1]   (sub-diagonal),
            T[1, i] = d f[i] / d y[i]     (main diagonal),
            T[2, i] = d f[i] / d y[i+1]   (super-diagonal).
        T[0, 0] and T[2, m-1] are filled from entries outside the band,
        which are zero when the Jacobian is tridiagonal.
    """
    m = y.size

    def map_body(i):
        # Unit tangent e_i, built *before* the JVP. With an all-zero tangent
        # every JVP, and hence the whole band, would be zero.
        t = jnp.zeros_like(y).at[i].set(1.0)

        _, dfy = jx.jvp(calc_f, (y,), (t,))

        # Window (row i-1, row i, row i+1) of column i. At the two ends the
        # slice start is shifted so the window stays inside dfy; the roll
        # then puts the diagonal entry back in the middle, and the entry
        # that wraps around lies outside the band.
        im1 = jnp.where(i > 0, i-1, 0)
        Ti = jlx.dynamic_slice(dfy, (im1,), (3,))
        Ti = jnp.where(i >   0, Ti, jnp.roll(Ti, shift=+1))
        Ti = jnp.where(i < m-1, Ti, jnp.roll(Ti, shift=-1))

        return Ti

    # Differentiate wrt y[:]; result has shape (m, 3) = (column, [up, diag, low])
    T = jlx.map(map_body, jnp.arange(m))

    # Correct the orientation of T: rows become (low, diag, up), then each
    # off-diagonal is shifted so that it is indexed by row instead of column
    T = T.transpose()
    T = jnp.flip(T, axis=0)
    T = T.at[0,:].set(jnp.roll(T[0,:], shift=+1))
    T = T.at[2,:].set(jnp.roll(T[2,:], shift=-1))

    return T
0

There are 0 best solutions below