How does one convert the following to `jax.lax.scan` (to accelerate compiling)? The `for` loop version works with `jax.jit`:
```python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=0)
def func(n):
    p = 1
    x = jnp.arange(8)
    y = jnp.zeros((n,))
    for idx in range(n):
        y = y.at[idx].set(jnp.sum(x[::p]))
        p = 2*p
    return y

func(2)
# >> Array([28., 12.], dtype=float32)
```
but using `scan` raises a static start/stop/step error:
```python
import numpy as np

def body(p, xi):
    y = jnp.sum(x[::p])
    p = 2*p
    return p, y

x = jnp.arange(8)
jax.lax.scan(body, 1, np.arange(2))
# >> IndexError: Array slice indices must have static start/stop/step ...
```
The issue here is that within `scan`, the `p` variable represents a dynamic (traced) value, meaning that `x[::p]` is a dynamically-sized array, so the operation is not allowed in JAX transformations (see JAX sharp bits: dynamic shapes).

Often in such cases it's possible to replace approaches that use dynamically-shaped intermediates with other approaches that compute the same thing using only statically-shaped arrays. In this case, one thing you might do is replace this problematic line:

`y = jnp.sum(x[::p])`

with this, which does the same sum using only statically-sized arrays:
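A minimal sketch of such a replacement, assuming a masking approach (one option among several; the helper name `strided_sum` is hypothetical). Rather than selecting every `p`-th element, which produces an array whose length depends on `p`, it keeps all of `x` and uses `jnp.where` to zero out the entries whose index is not a multiple of `p`, so every intermediate has the static shape of `x`:

```python
import jax
import jax.numpy as jnp

def strided_sum(x, p):
    # Hypothetical helper: same value as jnp.sum(x[::p]) for integer p >= 1,
    # but the mask and the masked array always have shape x.shape,
    # so p is allowed to be a traced (dynamic) value.
    idx = jnp.arange(x.shape[0])
    return jnp.sum(jnp.where(idx % p == 0, x, 0))

x = jnp.arange(8)

# Works under jit with p passed as a traced argument.
print(jax.jit(strided_sum)(x, 2))  # 12
```

Inside the `scan` body this would stand in for the slicing line as `y = strided_sum(x, p)`. The trade-off is that it always touches all of `x` instead of only the selected elements, which is usually fine for small arrays.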
Using this idea, here's a version of your original function that uses `scan`:
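A sketch of what such a `scan` version could look like, assuming the masking idea (zeroing out entries whose index is not a multiple of `p` with `jnp.where`), written inline so the block runs on its own. `xs=None` with `length=n` is a standard way to run `jax.lax.scan` a fixed number of times without per-step inputs; the cap on `p` and the final `float32` cast are choices of this sketch, not part of the original:

```python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=0)
def func(n):
    x = jnp.arange(8)
    idx = jnp.arange(x.shape[0])

    def body(p, _):
        # Masked sum: same value as jnp.sum(x[::p]), but with static shapes.
        y = jnp.sum(jnp.where(idx % p == 0, x, 0))
        # Cap p at len(x): for any p >= len(x), x[::p] is just [x[0]],
        # so the result is unchanged and the int32 carry cannot overflow.
        return jnp.minimum(2 * p, x.shape[0]), y

    # n is static, so it can fix the number of scan steps.
    # jnp.int32(1) keeps the carry dtype identical across iterations.
    _, ys = jax.lax.scan(body, jnp.int32(1), None, length=n)
    # Cast to float32 to match the dtype of the for-loop version.
    return ys.astype(jnp.float32)

print(func(2))  # [28. 12.]
```

Because the loop body is traced once instead of being unrolled `n` times, compilation time should no longer grow with `n`.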