JAX debugging flags#

JAX offers flags and context managers that enable catching errors more easily.

jax_debug_nans configuration option and context manager#

Summary: Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code.

jax_debug_nans is a JAX flag that, when enabled, causes computations to raise an error as soon as a NaN is produced. Switching this option on adds a NaN check to every floating-point value produced by XLA. For every primitive operation not under an @jax.jit, that means values are pulled back to the host and checked as ndarrays.

For code under an @jax.jit, the output of every @jax.jit function is checked; if a NaN is present, JAX re-runs the function in de-optimized op-by-op mode, effectively removing one level of @jax.jit at a time.

Some situations are tricky, such as NaNs that occur only under an @jax.jit and are not reproduced in de-optimized mode. In that case a warning message is printed, but your code continues to execute.

If the NaNs are produced in the backward pass of a gradient evaluation, then when the exception is raised you will find the backward_pass function several frames up in the stack trace. This function is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse.
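As a minimal sketch of the backward-pass case, the snippet below uses a hypothetical function, safe_log, whose forward value is finite but whose gradient at 0 involves a 0/0 during the backward pass. The function name and input are illustrative assumptions, and the checker is enabled only locally through the jax.debug_nans context manager.

```python
import jax
import jax.numpy as jnp


def safe_log(x):
    # Hypothetical example function: the forward value is finite everywhere,
    # because jnp.where selects 0.0 whenever x <= 0.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


with jax.debug_nans(True):
    # Forward pass: log(0) is -inf (not NaN) and is masked out by jnp.where.
    forward_value = safe_log(0.0)

    # Backward pass: the masked branch receives a zero cotangent, which is
    # then divided by x == 0, producing a NaN that the checker reports.
    try:
        jax.grad(safe_log)(0.0)
        backward_nan_caught = False
    except FloatingPointError:
        backward_nan_caught = True

print(forward_value, backward_nan_caught)
```

Without the checker, the same jax.grad call would be expected to return NaN silently, which is why this kind of masked-branch gradient can be hard to locate by inspection alone.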

Usage#

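If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by doing one of the following:

- running your code inside the jax.debug_nans context manager, using with jax.debug_nans(True):;
- setting the JAX_DEBUG_NANS=True environment variable;
- adding jax.config.update("jax_debug_nans", True) near the top of your main file;
- adding jax.config.parse_flags_with_absl() to your main file, then setting the option with a command-line flag like --jax_debug_nans=True.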

Example(s)#

import jax
import jax.numpy as jnp
import traceback

jax.config.update("jax_debug_nans", True)

def f(x):
    w = 3 * jnp.square(x)
    return jnp.log(-w)

# The stack trace is very long, so print only a couple of lines.
try:
    f(5.)
except FloatingPointError as e:
    print(traceback.format_exc(limit=2))

The NaN generated was caught. By running %debug, we can get a post-mortem debugger. This also works with functions under @jax.jit, as the example below shows.

jax.jit(f)(5.)

When this code sees a NaN in the output of an @jax.jit function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with %debug to inspect all the values to figure out the error.

The jax.debug_nans context manager can be used to activate/deactivate NaN debugging. Since we activated it above with jax.config.update, let’s deactivate it:

with jax.debug_nans(False):
    print(jax.jit(f)(5.))

Strengths and limitations of jax_debug_nans#

Strengths#
Limitations#

jax_debug_infs configuration option and context manager#

jax_debug_infs works similarly to jax_debug_nans. It often needs to be combined with jax_disable_jit, because Infs, unlike NaNs, might not propagate to the output. Alternatively, jax.experimental.checkify may be used to find Infs in intermediate values.

Full documentation of jax_debug_infs is forthcoming.
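The following sketch illustrates why jax_debug_infs is often paired with jax_disable_jit. It assumes a hypothetical function, inverse_exp, in which an intermediate Inf is turned back into a finite value before the output; the function and its input are illustrative and not taken from the JAX documentation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inverse_exp(x):
    # Hypothetical example: exp(x) overflows to inf for large x,
    # but 1 / inf == 0, so the Inf never reaches the output.
    return 1.0 / jnp.exp(x)


with jax.debug_infs(True):
    # Under jit only the final output is checked, so nothing is reported.
    jitted_output = inverse_exp(1000.0)

    # With JIT disabled each primitive runs eagerly and is checked in turn,
    # so the intermediate Inf from jnp.exp raises an error.
    with jax.disable_jit():
        try:
            inverse_exp(1000.0)
            intermediate_inf_caught = False
        except FloatingPointError:
            intermediate_inf_caught = True

print(jitted_output, intermediate_inf_caught)
```

Both flags are set here through context managers so that the change stays local to the block being investigated.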

jax_disable_jit configuration option and context manager#

Summary: Enable the jax_disable_jit flag to disable JIT-compilation, enabling use of traditional Python debugging tools like print and pdb.

jax_disable_jit is a JAX flag that, when enabled, disables JIT-compilation throughout JAX (including in control flow functions like jax.lax.cond and jax.lax.scan).
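To illustrate the control-flow point, here is a small sketch that runs jax.lax.scan inside the jax.disable_jit context manager. The step function and the running_totals list are hypothetical names introduced for this example. With JIT disabled, the scan body is expected to receive concrete values, so ordinary Python calls such as float() work inside it.

```python
import jax
import jax.numpy as jnp

running_totals = []  # filled by a Python side effect inside the scan body


def step(carry, x):
    new_carry = carry + x
    # float() needs a concrete value; it would fail on a tracer, which is
    # what the scan body normally receives.
    running_totals.append(float(new_carry))
    return new_carry, new_carry


with jax.disable_jit():
    total, partial_sums = jax.lax.scan(
        step, jnp.float32(0.0), jnp.arange(1.0, 4.0)
    )

print(running_totals)
```

The same pattern applies to print or breakpoint() calls inside the body. Outside the context manager, the call to float() on a traced value would raise an error instead.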

Usage#

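You can disable JIT-compilation by doing one of the following:

- running your code inside the jax.disable_jit context manager, using with jax.disable_jit():;
- setting the JAX_DISABLE_JIT=True environment variable;
- adding jax.config.update("jax_disable_jit", True) near the top of your main file;
- adding jax.config.parse_flags_with_absl() to your main file, then setting the option with a command-line flag like --jax_disable_jit=True.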

Examples#

import jax
import jax.numpy as jnp

jax.config.update("jax_disable_jit", True)

def f(x):
    y = jnp.log(x)
    if jnp.isnan(y):
        breakpoint()
    return y

jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!

Strengths and limitations of jax_disable_jit#

Strengths#
Limitations#