The exact meaning of n_jitted_steps=5

I have been trying to run the score_sde code. It has a config setting, n_jitted_steps=5, which, according to the authors, accumulates several training steps. The code is rather complicated, so it is hard to tell what this setting actually does. I therefore tried the following in a Colab notebook, where the relevant cell is:

import jax
import optax

@jax.jit(n_jitted_steps=5)  # raises an error: jax.jit has no n_jitted_steps argument
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
    return loss
  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

As expected, this raises an error. My questions are:

  1. Does n_jitted_steps=5 run five training steps in one go, similar to loop unrolling?
  2. If so, what is the correct way to use it?

Thanks in advance.

In the score_sde repository README, the authors say the following:

When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via config.training.n_jitted_steps

Note that n_jitted_steps is a config parameter defined by the score_sde repository, not an argument of jax.jit. In the notebook you link to, you seem to be using flax without any code from score_sde. Given that, I don't think the question applies as asked, because flax has no equivalent of the n_jitted_steps parameter.
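That said, the idea the README describes (compiling several training steps into one jitted call) can be reproduced by hand. The following is a minimal sketch, assuming plain JAX with no flax or optax, and a toy linear-regression step in place of your model. It stacks N batches along a leading axis and lets jax.lax.scan apply the step N times inside a single jax.jit call. The names N_JITTED_STEPS, train_step and multi_step are hypothetical and chosen only for illustration; this is not claimed to be score_sde's actual implementation.

```python
import jax
import jax.numpy as jnp

N_JITTED_STEPS = 5  # hypothetical stand-in for config.training.n_jitted_steps
LEARNING_RATE = 0.1


def train_step(params, batch):
    """One plain SGD update on a single batch; returns (new_params, loss)."""
    def loss_fn(p):
        pred = batch["x"] @ p["w"] + p["b"]
        return jnp.mean((pred - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


@jax.jit
def multi_step(params, stacked_batches):
    # Every leaf of stacked_batches has a leading axis of size N_JITTED_STEPS.
    # lax.scan feeds one slice per iteration to train_step, carrying params
    # forward, so all N updates happen inside one compiled call.
    return jax.lax.scan(train_step, params, stacked_batches)


# Toy data: N_JITTED_STEPS batches of 32 examples with 3 features each.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (N_JITTED_STEPS, 32, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
final_params, losses = multi_step(params, {"x": x, "y": y})
print(losses)  # one loss per step, computed before that step's update
```

Regarding question 1: this is related to loop unrolling in spirit, but lax.scan does not unroll by default; it keeps a single copy of the step in the compiled program and loops over it on the device, which avoids returning to Python between steps. The stacked input holding all N batches has to be on the device at once, which is presumably part of the extra memory cost the README mentions.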