jax.debug module — JAX documentation

jax.debug module

Runtime value debugging utilities

The guide Compiled prints and breakpoints describes how to use JAX’s runtime value debugging features.

callback(callback, *args[, ordered, partitioned]) - Calls a stageable Python callback.
print(fmt, *args[, ordered, partitioned]) - Prints values and works in staged-out JAX functions.
breakpoint(*[, backend, filter_frames, ...]) - Enters a breakpoint at a point in a program.
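As a rough illustration of print and callback, here is a minimal sketch assuming a recent JAX release running on CPU. Inside a jax.jit function, Python's built-in print only sees abstract tracers, whereas jax.debug.print and jax.debug.callback are staged into the compiled program and see concrete runtime values. The names record and seen are hypothetical, chosen for this example only. breakpoint is left out because it opens an interactive debugger.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side log filled by the staged callback (illustration only).
seen = []

def record(v):
    # Runs on the host with a concrete array value, not a tracer.
    seen.append(float(v))

@jax.jit
def step(x):
    y = jnp.sin(x) * 2.0
    # Prints the runtime value of y each time the compiled function executes.
    jax.debug.print("y = {y}", y=y)
    # Stages a call to an arbitrary Python function; ordered=True keeps it
    # in program order relative to other ordered effects.
    jax.debug.callback(record, y, ordered=True)
    return y + 1.0

out = step(jnp.float32(0.5))
# Wait until all pending debug effects (prints and callbacks) have run.
jax.effects_barrier()
```

The callback returns nothing and is used purely for its side effect. For the plain "print a value" case, jax.debug.print is the more direct tool.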

Sharding debugging utilities

Functions for inspecting and visualizing array shardings, both inside and outside staged functions.
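As an example of inspecting a sharding from inside a staged function, the sketch below uses jax.debug.inspect_array_sharding. It assumes a recent JAX release and works on a single CPU device. The callback fires when the function is compiled, not on each call, and receives the sharding XLA chose for the intermediate value. The list name reported is hypothetical and used only for illustration. Outside staged code, a concrete array's sharding is available directly through its .sharding attribute.

```python
import jax
import jax.numpy as jnp

# Hypothetical log of shardings reported during compilation (illustration only).
reported = []

@jax.jit
def f(x):
    y = x * 2
    # Called at compile time with the sharding chosen for y.
    jax.debug.inspect_array_sharding(y, callback=reported.append)
    return y

x = jnp.arange(8.0)
f(x)

# Outside of staged code, a concrete array exposes its sharding directly.
print(x.sharding)
```

On a single device, the reported sharding is trivial. The same call becomes more informative when the inputs are sharded across several devices.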