Idiomatic ways to handle errors in JAX jitted functions

As the title states, I'd like to know what idiomatic methods are available to raise exceptions or handle errors in JAX jitted functions. The functional nature of JAX makes it unclear how to accomplish this.

The closest official documentation I could find is the jax.experimental.checkify module, but this wasn't very clear and seemed incomplete.

This GitHub comment claims that Python exceptions can be raised by combining jax.debug.callback() and jax.lax.cond(). I attempted to do this, but an error is thrown during compilation. A minimal reproducible example is below:

import jax
from jax import jit

def _raise(ex):
    raise ex


@jit
def error_if_positive(x):
    jax.lax.cond(
        x > 0,
        lambda : jax.debug.callback(_raise, ValueError("x is positive")),
        lambda : None,
    )

if __name__ == "__main__":

    error_if_positive(-1)

The abbreviated error message:

TypeError: Value ValueError('x is positive') with type <class 'ValueError'> is not a valid JAX type

Best answer

You can use callbacks to raise errors, for example:

import jax
from jax import jit

def _raise_if_positive(x):
  if x > 0:
    raise ValueError("x is positive")

@jit
def error_if_positive(x):
  jax.debug.callback(_raise_if_positive, x)

if __name__ == "__main__":
  error_if_positive(-1)  # no error
  error_if_positive(1)
  # XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: x is positive

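Your approach didn't work because the error occurs at trace time rather than at runtime. Both branches of the cond are always traced, whatever the value of x, and while tracing the first branch the ValueError instance is passed to jax.debug.callback as an argument. Callback arguments must be valid JAX types such as arrays, hence the TypeError. Raising the exception inside the callback, based on the runtime value of x, avoids the problem.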
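
Since the question mentions jax.experimental.checkify, here is a minimal sketch of how the same check might look with it. The function name double_if_not_positive and its return value are hypothetical, chosen only for illustration. The idea is that checkify.check records a failure instead of raising inside traced code, and the wrapped function returns an error object alongside its normal output, so the exception is raised on the Python side after the jitted call returns.

```python
import jax
from jax.experimental import checkify


def double_if_not_positive(x):
  # Hypothetical function for illustration: the check records a failure
  # when the predicate is False instead of raising inside the traced code.
  checkify.check(x <= 0, "x is positive")
  return x * 2


# checkify wraps the function so that it returns (error, output);
# the wrapped function can then be jitted as usual.
checked_fn = jax.jit(checkify.checkify(double_if_not_positive))

err_ok, out_ok = checked_fn(-1.0)
err_ok.throw()  # no failed check, so nothing is raised

err_bad, out_bad = checked_fn(1.0)
try:
  err_bad.throw()  # re-raises the recorded failure as a Python exception
except Exception as e:
  print(f"caught: {e}")
```

Compared with the callback approach, the error is ordinary data returned from the function, so the caller decides when (or whether) to raise it. err.get() returns the message as a string, or None if no check failed.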