Jax scan with dynamic number of iterations


I want to perform a scan with a dynamic number of iterations. To accomplish that, I'm willing to recompile the function each time iters_to_do changes.

To avoid a huge slowdown I'll be using a recompilation cache, but that's beside the point.

However, even when I mark the argument as static with @partial(jax.jit, static_argnums=...), I still get a concretization error:

@partial(jax.jit, static_argnums=(3))
def iterate_for_steps(self,
                        interim_thought: Array, 
                        mask: Array,
                        iters_to_do: int, 
                        input_arr: Array, 
                        key: PRNGKeyArray) -> Array:

    # These are constants
    input_arr = input_arr.astype(jnp.bfloat16)
    interim_thought = interim_thought.astype(jnp.bfloat16)
    
    def body_fun(i: int, thought: Array) -> Array:
        latent = jnp.concatenate([thought, input_arr], axis=-1).astype(jnp.bfloat16)
        latent = self.main_block(latent, input_arr, mask, key).astype(jnp.bfloat16)
        latent = jax.vmap(self.post_ln)(latent).astype(jnp.bfloat16)  # LN to keep scales tidy

        return latent
    
    iters_to_do = iters_to_do.astype(int).item()
    final_val = jax.lax.scan(body_fun, interim_thought, xs=None, length=iters_to_do)
    
    return final_val

Full traceback is here.

I've tried marking multiple arguments as static in @partial(jax.jit, ...), but to no avail.

I'm not sure how to approach debugging this. With a Python debugger, the only thing I can tell is that it's definitely a tracer.

MRE

from functools import partial
import jax
import jax.numpy as jnp

init = jnp.ones((5,))
iterations = jnp.array([1, 2, 3])

@partial(jax.jit, static_argnums=(0,))
def iterate_for_steps(iters: int):
    def body_fun(carry):
        return carry * 2
    
    iters = iters.astype(int)
    output = jax.lax.scan(body_fun, init, xs=None, length=iters)
    
    return output

print(jax.vmap(iterate_for_steps)(iterations))

Answers

neel g (accepted answer)

One can use Equinox's while_loop implementation (internal as of right now), which can handle a dynamic number of iterations and uses checkpointing to reduce memory usage.

Note that this can be used as a drop-in replacement for JAX's native while_loop. If you want similar checkpointing with scan, Equinox also provides eqx.internal.scan.
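If you'd rather not depend on an internal API, a related idea can be sketched in plain JAX: run jax.lax.scan for a fixed, static upper bound of steps and freeze the carry once the dynamic count is reached. This is a swapped-in technique, not Equinox's implementation. MAX_STEPS and iterate_bounded are hypothetical names, the doubling body is borrowed from the MRE, and jax.checkpoint here simply rematerializes each step on the backward pass rather than reproducing Equinox's checkpointing scheme.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 8  # hypothetical static upper bound on the iteration count


def iterate_bounded(init, n):
    # n may be traced (or batched under vmap); only MAX_STEPS must be static.
    def body_fun(carry, i):
        # Apply the update only while i < n; afterwards the carry passes through unchanged.
        new_carry = jnp.where(i < n, carry * 2, carry)
        return new_carry, None

    # jax.checkpoint recomputes each step's intermediates during the backward
    # pass instead of storing them, trading compute for memory.
    out, _ = jax.lax.scan(jax.checkpoint(body_fun), init, jnp.arange(MAX_STEPS))
    return out


init = jnp.ones((5,))
iterations = jnp.array([1, 2, 3])

# A different iteration count per batch element, as in the MRE.
batched = jax.vmap(iterate_bounded, in_axes=(None, 0))(init, iterations)
print(batched)

# Unlike jax.lax.while_loop, this loop supports reverse-mode differentiation.
grad = jax.grad(lambda x: iterate_bounded(x, 3).sum())(init)
print(grad)
```

The trade-off is that every call executes all MAX_STEPS iterations, so this only makes sense when a reasonable upper bound on the count is known in advance.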

jakevdp

First of all, the number of iterations in a scan must be static. If you want something similar to scan that allows a dynamic number of iterations, you can take a look at jax.lax.while_loop.
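A minimal sketch of the while_loop route, assuming the doubling body and inputs from the MRE in the question. Here the iteration count is an ordinary traced value, so the function compiles once and also works under jax.vmap with a different count for each element.

```python
import jax
import jax.numpy as jnp

init = jnp.ones((5,))


@jax.jit
def iterate_for_steps(iters):
    # iters can be traced: while_loop decides the trip count at runtime.
    def cond_fun(state):
        i, _ = state
        return i < iters

    def body_fun(state):
        i, carry = state
        return i + 1, carry * 2

    _, out = jax.lax.while_loop(cond_fun, body_fun, (0, init))
    return out


iterations = jnp.array([1, 2, 3])
print(jax.vmap(iterate_for_steps)(iterations))
```

Under vmap the batched loop keeps running until every element's condition is false, leaving already-finished elements unchanged. jax.lax.fori_loop with a traced upper bound is lowered to a while_loop as well. Neither supports reverse-mode differentiation (jax.grad) when the trip count is dynamic.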

Regarding your code: in isolation, your fix of marking iters_to_do as static using static_argnums is probably roughly the right idea, so long as you are passing a static int in this position when you call the function.
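For reference, here is a sketch of the static approach, assuming iters is passed as a plain Python int rather than an array. It differs from the question's code in two ways: the static value is used directly (no .astype or .item()), and the scan body has the (carry, x) -> (new_carry, y) signature that jax.lax.scan requires.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def iterate_for_steps(init, iters):
    # iters arrives as a plain Python int, so it can be passed straight to `length`.
    def body_fun(carry, _):
        # jax.lax.scan expects (carry, x) -> (new_carry, y); y is unused here.
        return carry * 2, None

    final_carry, _ = jax.lax.scan(body_fun, init, xs=None, length=iters)
    return final_carry


init = jnp.ones((5,))
print(iterate_for_steps(init, 3))  # traces and compiles for iters=3
print(iterate_for_steps(init, 3))  # reuses the cached compilation
print(iterate_for_steps(init, 4))  # new static value: retraces and recompiles
```

Because static arguments must be hashable Python values, this version cannot be mapped with jax.vmap over an array of iteration counts as in the MRE. That case needs a loop whose trip count can be traced.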

But the fact that you are calling the astype array method in your function (in iters_to_do.astype(int).item()) and getting a ConcretizationTypeError rather than an AttributeError makes me think that the error you linked to is not coming from the code as pasted in your question.
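A small sketch of the two failure modes being distinguished here. The function names are hypothetical. When iters is marked static, JAX passes a Python int, which has no astype method. Without static_argnums, iters is a tracer, and passing it as scan's length fails with a concretization error.

```python
from functools import partial

import jax
import jax.numpy as jnp


def body_fun(carry, _):
    return carry * 2, None


@partial(jax.jit, static_argnums=(0,))
def static_version(iters):
    # iters is a plain Python int here, so calling .astype raises AttributeError.
    n = iters.astype(int)
    out, _ = jax.lax.scan(body_fun, jnp.ones((5,)), xs=None, length=n)
    return out


@jax.jit
def traced_version(iters):
    # iters is a tracer here, so scan cannot read a concrete length from it.
    out, _ = jax.lax.scan(body_fun, jnp.ones((5,)), xs=None, length=iters)
    return out


def error_name(fn, *args):
    """Return the name of the exception raised by fn(*args), or None."""
    try:
        fn(*args)
    except Exception as e:
        return type(e).__name__
    return None


print(error_name(static_version, 3))  # AttributeError
print(error_name(traced_version, 3))  # ConcretizationTypeError
```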

To help address this discrepancy, I'd suggest trying to construct a minimal reproducible example of the problem you're having. Without that, any answer to your question is going to require too much guesswork regarding what code you're actually executing.