Convert for loop to jax.lax.scan

How can I convert the following for loop to jax.lax.scan (to speed up compilation)? The for-loop version works with jax.jit:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=0)
def func(n):

    p = 1
    x = jnp.arange(8)
    y = jnp.zeros((n,))

    for idx in range(n):
        y = y.at[idx].set(jnp.sum(x[::p]))
        p = 2*p

    return y

func(2)
# >> Array([28., 12.], dtype=float32)

but the scan version raises a static start/stop/step error:

import numpy as np

def body(p, xi):

    y = jnp.sum(x[::p])

    p = 2*p

    return p, y

x = jnp.arange(8)

jax.lax.scan(body, 1, np.arange(2))
# >> IndexError: Array slice indices must have static start/stop/step ...

Best answer

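The issue here is that within scan, p is a traced (dynamic) value, so x[::p] would be a dynamically-sized array, which JAX transformations do not allow (see JAX sharp bits: dynamic shapes). The for-loop version works because jit unrolls the Python loop at trace time, so p is an ordinary Python int there and every slice has a static shape.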
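To make the static/traced distinction concrete, here is a minimal sketch. The helper names strided_sum_static and strided_sum_traced are hypothetical and not part of the question. The first marks p as static, so the slice step is a plain Python int; the second lets p be traced, which is the situation inside scan. The exception is caught broadly because the question reports an IndexError, but the exact type may differ between JAX versions.

```python
import functools
import jax
import jax.numpy as jnp

x = jnp.arange(8)

# p is static: jit sees a Python int, so x[::p] has a known shape.
@functools.partial(jax.jit, static_argnums=1)
def strided_sum_static(x, p):
    return jnp.sum(x[::p])

# p is traced: the slice step is unknown at trace time, so this fails.
@jax.jit
def strided_sum_traced(x, p):
    return jnp.sum(x[::p])

static_result = int(strided_sum_static(x, 2))

try:
    strided_sum_traced(x, 2)
    traced_failed = False
except Exception:
    # The question reports IndexError here; caught broadly as a precaution.
    traced_failed = True
```

Note that a static p triggers a recompile for every new value, which is effectively what the unrolled for loop does.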

Often in such cases it's possible to replace approaches that use dynamically-shaped intermediates with approaches that compute the same thing using only statically-shaped arrays. In this case, one thing you might do is replace this problematic line:

jnp.sum(x[::p])

with this, which does the same sum using only statically-sized arrays:

jnp.sum(x, where=jnp.arange(len(x)) % p == 0)
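As a quick sanity check, the sketch below compares the two forms for a few step sizes. The function names strided_sum and masked_sum are made up for illustration; the strided version runs outside jit with a Python int, while the masked version runs under jit with a traced p.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8)

def strided_sum(x, p):
    # Reference version: p must be a Python int, since slicing needs a static step.
    return jnp.sum(x[::p])

@jax.jit
def masked_sum(x, p):
    # p may be traced: the mask always has the same static shape as x.
    mask = jnp.arange(x.shape[0]) % p == 0
    return jnp.sum(x, where=mask)

# Pairs of (strided, masked) results for each step size.
results = [(int(strided_sum(x, p)), int(masked_sum(x, p))) for p in (1, 2, 4, 8)]
```

The masked form still touches every element of x on each call, so it trades some extra arithmetic for a static shape.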

Using this idea, here's a version of your original function that uses scan:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=0)
def func_scan(n):
    p = 1
    x = jnp.arange(8)
    y = jnp.zeros((n,))

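    def body(carry, _):
        # carry holds the write index, the output buffer, and the current step
        idx, y, p = carry
        # masked sum over every p-th element; the mask has a static shape
        y = y.at[idx].set(jnp.sum(x, where=jnp.arange(len(x)) % p == 0))
        return (idx + 1, y, 2 * p), None

    # no per-step inputs (xs=None), so the iteration count comes from length=n
    (_, y, _), _ = jax.lax.scan(body, (0, y, p), xs=None, length=n)
    return y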

func_scan(2)
# Array([28., 12.], dtype=float32)
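A possible variation, sketched here under the same assumptions: since scan stacks the second return value of body across iterations, the output buffer and index do not need to live in the carry. The name func_scan_ys is hypothetical, and the final cast to float32 is only there to match the dtype of the original function.

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=0)
def func_scan_ys(n):
    x = jnp.arange(8)
    positions = jnp.arange(x.shape[0])

    def body(p, _):
        # Emit one masked sum per step; scan stacks these into an array.
        y = jnp.sum(x, where=positions % p == 0)
        return 2 * p, y

    # The carry is just the step size p; the stacked outputs have shape (n,).
    _, ys = jax.lax.scan(body, 1, xs=None, length=n)
    return ys.astype(jnp.float32)

out = func_scan_ys(2)
```

This keeps the carry small and avoids the indexed update, at the cost of an explicit cast if a float result is wanted.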