JAX documentation: jax.debug module
jax.debug module
Runtime value debugging utilities
The "Compiled prints and breakpoints" guide describes how to use JAX's runtime value debugging features.
| Function | Description |
|---|---|
| `callback(callback, *args[, ordered, partitioned])` | Calls a stageable Python callback. |
| `print(fmt, *args[, ordered, partitioned])` | Prints values and works in staged-out JAX functions. |
| `breakpoint(*[, backend, filter_frames, ...])` | Enters a breakpoint at a point in a program. |
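A minimal sketch of `jax.debug.print` and `jax.debug.callback` inside `jax.jit`-compiled functions. A plain Python `print()` in a jitted function runs only once, at trace time, and shows an abstract tracer; the staged versions run when the compiled code executes and see concrete values. The names `scaled_sum`, `sine_times_two`, `record`, and `recorded` are hypothetical, and `jax.effects_barrier()` is used only to make sure the host-side callbacks have finished before their results are inspected.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def scaled_sum(x):
    y = x * 2.0
    # A plain print(y) here would run once at trace time and show a tracer.
    # jax.debug.print is staged into the compiled function and shows runtime values.
    jax.debug.print("y = {y}", y=y)
    return y.sum()

# Hypothetical host-side storage for values passed to the callback.
recorded = []

def record(value):
    # Runs on the host with a concrete array, not a tracer.
    recorded.append(np.asarray(value))

@jax.jit
def sine_times_two(x):
    s = jnp.sin(x)
    # Stage a call to the Python function `record` with the runtime value of s.
    jax.debug.callback(record, s)
    return s * 2.0

total = scaled_sum(jnp.arange(3.0))
out = sine_times_two(jnp.array([0.0, 1.0]))
jax.effects_barrier()  # wait for pending debug callbacks to complete
```

Both functions also accept `ordered=True`, which asks JAX to run the callback in program order relative to other ordered callbacks; by default no such ordering is enforced.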
Sharding debugging utilities
Functions for inspecting and visualizing array shardings, both inside staged functions (for example, under `jax.jit`) and outside them.
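A minimal sketch of inspecting a sharding inside and outside a staged function, assuming a default CPU or GPU setup with no explicit sharding annotations. It uses `jax.debug.inspect_array_sharding`, whose keyword-only `callback` receives a `jax.sharding.Sharding`; under `jit` this callback is, as we understand it, invoked while the function is compiled rather than on every execution. Outside staged code, a concrete array's `.sharding` attribute gives the same kind of object. The names `f` and `seen` are hypothetical. A text diagram is also available through `jax.debug.visualize_array_sharding`, which needs the optional `rich` package and is left out here.

```python
import jax
import jax.numpy as jnp

# Hypothetical list collecting the shardings reported inside the jitted function.
seen = []

@jax.jit
def f(x):
    # Report how x is laid out inside the compiled computation.
    # The callback is called with a jax.sharding.Sharding (at compile time, not per call).
    jax.debug.inspect_array_sharding(x, callback=seen.append)
    return x + 1

x = jnp.ones((8, 4))
y = f(x)

print(seen[0])     # sharding of x as seen inside the compiled function
print(y.sharding)  # outside jit, concrete arrays expose .sharding directly
```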