jax.errors.UnexpectedTracerError only when using jax.debug.breakpoint()

My JAX code runs fine, but when I insert a breakpoint with jax.debug.breakpoint() I get jax.errors.UnexpectedTracerError.
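For reference, here is a minimal sketch of the pattern described above. The function f is hypothetical, since the question does not show the actual code. jax.debug.breakpoint() is placed inside a function that JAX traces. The sketch traces f with jax.make_jaxpr instead of calling it, so the breakpoint is staged into the computation but the interactive debugger is never entered.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical stand-in for the asker's function.
    y = jnp.sin(x)
    # Stages a debugger callback; when f actually runs, execution pauses here.
    jax.debug.breakpoint()
    return y * 2.0

# Tracing only: builds the jaxpr without running the breakpoint callback.
closed_jaxpr = jax.make_jaxpr(f)(1.0)
print(closed_jaxpr)
```

Calling f (or jax.jit(f)) instead of tracing it would stop at the breakpoint and wait for debugger input.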

If a tracer were really leaking, I would expect this error to show up without the breakpoint as well.

Is this intended behavior, or is something weird happening? When I enable leak checking with jax.checking_leaks, none of the reported tracers actually seem to be leaked.
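For reference, leak checking can be enabled for a block of code with the jax.checking_leaks() context manager, or globally with the jax_check_tracer_leaks config option. With it on, JAX checks at tracing time whether any tracers are still referenced after their trace finishes. The sketch below applies both options to a hypothetical jitted function g that keeps no references to its tracers, so neither raises an error.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    # Hypothetical function that does not store its tracers anywhere.
    return jnp.sin(x) * 2.0

# Option 1: enable leak checking only inside this block.
with jax.checking_leaks():
    y = g(jnp.arange(3.0))

# Option 2: toggle the global config flag, then restore the default.
jax.config.update("jax_check_tracer_leaks", True)
try:
    z = g(jnp.arange(4.0))
finally:
    jax.config.update("jax_check_tracer_leaks", False)
```

The context manager is the less intrusive choice, since it cannot accidentally leave leak checking on for unrelated code.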

Best answer

There is currently a bug in jax.debug.breakpoint that can lead to spurious tracer-leak errors in some situations; see https://github.com/google/jax/issues/16732.

Unfortunately, there isn't an easy workaround at the moment, but hopefully the issue will be addressed soon.