I have tried to run the code from the `score_sde` repository. It has a parameter called `n_jitted_steps=5`, which, according to the authors, can accumulate several steps. Since the code is rather complicated, it is difficult to understand. So I tried the following in Colab, where the relevant cell is:
```python
@jax.jit(n_jitted_steps=5)
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
    return loss
  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state
```
Obviously, this raises an error. My questions are:

- Is the purpose of `n_jitted_steps=5` to run five steps in one go, perhaps similar to loop unrolling?
- If that is the case, what is the correct way to use it?
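To make the loop-unrolling comparison concrete, here is a minimal sketch in plain JAX of two ways to run several training steps inside one compiled call. It is not taken from `score_sde`; the toy linear model, `loss_fn`, `train_step`, `unrolled_steps`, `scanned_steps`, `N_STEPS`, and `LEARNING_RATE` are all hypothetical stand-ins chosen so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

N_STEPS = 5          # hypothetical stand-in for n_jitted_steps
LEARNING_RATE = 0.1  # hypothetical toy value

def loss_fn(w, batch):
    # Mean squared error of a toy linear model y ≈ x @ w.
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)

def train_step(w, batch):
    """One plain SGD step on a single batch."""
    return w - LEARNING_RATE * jax.grad(loss_fn)(w, batch)

@jax.jit
def unrolled_steps(w, batches):
    # The Python loop runs while tracing, so the compiled program contains
    # one copy of train_step per batch: literal loop unrolling.
    xs, ys = batches
    for i in range(xs.shape[0]):
        w = train_step(w, (xs[i], ys[i]))
    return w

@jax.jit
def scanned_steps(w, batches):
    # lax.scan traces train_step once and loops over the leading axis of
    # every array in `batches` inside the compiled program.
    def body(carry, batch):
        return train_step(carry, batch), None
    w, _ = jax.lax.scan(body, w, batches)
    return w

# N_STEPS batches of 32 examples each, stacked along a new leading axis.
xs = jax.random.normal(jax.random.PRNGKey(0), (N_STEPS, 32, 3))
ys = xs @ jnp.array([1.0, -2.0, 0.5])  # shape (N_STEPS, 32)
w0 = jnp.zeros(3)

w_unrolled = unrolled_steps(w0, (xs, ys))
w_scanned = scanned_steps(w0, (xs, ys))
print(jnp.allclose(w_unrolled, w_scanned, atol=1e-5))
```

Both functions apply five SGD updates per call and should agree up to floating-point rounding. The difference is at compile time: the unrolled program grows with the number of steps, while `lax.scan` compiles the step body only once.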
 
Thanks in advance.
                        
In the repository README, the authors say the following:
Note that `n_jitted_steps` is a parameter defined by the `score_sde` repository. In the notebook you link to, it seems like you're using `flax`, and not using any code from `score_sde`. Given that, I don't think your question is meaningful, because there is no equivalent of the `n_jitted_steps` parameter in `flax`.
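If the goal is still to run several steps per compiled call in a `flax` notebook, one option is to wrap the existing step function yourself. The following is a minimal sketch, assuming a step function with the notebook's `train_step(state, batch) -> state` signature. `split_into_steps`, `make_multi_step`, and `toy_train_step` are hypothetical names, and a plain dict stands in for flax's `TrainState` so the snippet runs without flax. This is neither a flax API nor the `score_sde` implementation.

```python
import jax
import jax.numpy as jnp

def split_into_steps(batch, n_steps):
    """Reshape each array in `batch` from (n_steps * B, ...) to (n_steps, B, ...)."""
    return jax.tree_util.tree_map(
        lambda a: a.reshape((n_steps, a.shape[0] // n_steps) + a.shape[1:]), batch)

def make_multi_step(train_step):
    """Wrap `train_step(state, batch) -> state` so a single jitted call applies
    it once per slice along the leading axis of `batches`."""
    @jax.jit
    def multi_step(state, batches):
        def body(carry, batch):
            return train_step(carry, batch), None
        state, _ = jax.lax.scan(body, state, batches)
        return state
    return multi_step

def toy_train_step(state, batch):
    # Hypothetical stand-in for the notebook's train_step: one SGD step on a
    # linear model, using the same 'image' / 'label' batch keys.
    def loss_fn(params):
        preds = batch['image'] @ params
        return jnp.mean((preds - batch['label']) ** 2)
    grads = jax.grad(loss_fn)(state['params'])
    return {'params': state['params'] - 0.1 * grads, 'step': state['step'] + 1}

n_steps = 5
big_batch = {'image': jnp.ones((n_steps * 8, 3)), 'label': jnp.ones((n_steps * 8,))}
batches = split_into_steps(big_batch, n_steps)  # leaves now have shape (5, 8, ...)
state = {'params': jnp.zeros(3), 'step': jnp.array(0)}
state = make_multi_step(toy_train_step)(state, batches)
print(int(state['step']))
```

Because flax's `TrainState` is itself a pytree, it should be usable as the `lax.scan` carry in the same way; the data loader then has to yield `n_steps` batches at once, which is what `split_into_steps` does here.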