jax.debug.callback — JAX documentation
jax.debug.callback
jax.debug.callback(callback, *args, ordered=False, partitioned=False, **kwargs)
Calls a stageable Python callback.
For more explanation, see External Callbacks.
jax.debug.callback enables you to pass in a Python function that can be called inside a staged JAX program. A jax.debug.callback follows the pure operational semantics of existing JAX transformations, which are therefore unaware of side effects. This means the callback's effect may be dropped, duplicated, or reordered in the presence of higher-order primitives and transformations.
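A minimal sketch of the basic pattern, assuming a recent JAX release: the callback is staged into a jitted function and, when the compiled program runs, receives concrete array values rather than tracers. The names `log_value`, `seen`, and `f` are illustrative only. `jax.effects_barrier()` is used to wait for pending callbacks before the host-side list is inspected.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # host-side log filled in by the callback


def log_value(x):
    # x arrives as a concrete array, not a tracer; keep a NumPy copy of it
    seen.append(np.asarray(x))


@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    jax.debug.callback(log_value, y)  # staged into the compiled program
    return y + 1.0


out = f(jnp.arange(3.0))
jax.effects_barrier()  # block until pending callbacks have run
print(seen[0])
```

The callback's return value is ignored, so it is used here purely for its side effect of recording `y`.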
We want this behavior because we'd like jax.debug.callback to be "innocuous": debug callbacks should change the JAX computation as little as possible while revealing as much about it as possible, such as which parts of the computation are duplicated or dropped.
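To see this "revealing" behavior in practice, the sketch below calls a function containing a debug callback under `jax.vmap`. In the JAX versions we are aware of, the batching rule for debug callbacks unrolls the callback over the mapped axis, so it runs once per batch element; treat that as observed behavior rather than a documented guarantee. Because the callback is unordered, no particular invocation order is assumed (hence the `sorted`). The names `record`, `calls`, and `f` are illustrative.

```python
import jax
import jax.numpy as jnp

calls = []  # one entry per callback invocation


def record(x):
    # Under vmap the callback is expected to see one unbatched element at a time
    calls.append(float(x))


def f(x):
    jax.debug.callback(record, x)
    return x * 2.0


ys = jax.vmap(f)(jnp.arange(3.0))
jax.effects_barrier()
print(len(calls), sorted(calls))
```

A single traced call site thus shows up as several host-side invocations, which is exactly the kind of duplication the callback is meant to expose.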
Parameters:
- callback (Callable[..., None]) – A Python callable returning None.
- *args (Any) – The positional arguments to the callback.
- ordered (bool) – A keyword-only argument indicating whether the staged-out computation enforces the ordering of this callback with respect to other ordered callbacks.
- partitioned (bool) – If True, the callback receives only the local shards of the operands; this avoids an all-gather of the operands. If False, the callback receives the logical (global) operands; this requires an all-gather of the operands first.
- **kwargs (Any) – The keyword arguments to the callback.
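The next sketch exercises `ordered=True` and `**kwargs`. With `ordered=True`, the two callbacks are expected to run in program order relative to each other. Extra keyword arguments (here the hypothetical `scale`) are forwarded to the callback. We assume that, like positional arguments, keyword argument values should be arrays or numeric scalars, so a non-array value such as a string label is bound into the callback with `functools.partial` instead. All names other than the JAX APIs are illustrative.

```python
import functools

import jax
import jax.numpy as jnp

events = []  # (label, scaled value) pairs, in the order the callbacks ran


def record(label, x, *, scale):
    # label is a plain Python string bound via functools.partial;
    # x and scale arrive as concrete arrays
    events.append((label, float(x) * float(scale)))


@jax.jit
def g(x):
    a = x + 1.0
    jax.debug.callback(functools.partial(record, "first"), a,
                       ordered=True, scale=2.0)
    b = a * 3.0
    jax.debug.callback(functools.partial(record, "second"), b,
                       ordered=True, scale=2.0)
    return b


g(jnp.float32(1.0))
jax.effects_barrier()
print(events)  # [('first', 4.0), ('second', 12.0)]
```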
Returns:
None
Return type:
None
See also
- jax.experimental.io_callback(): callback designed for impure functions.
- jax.pure_callback(): callback designed for pure functions.
- jax.debug.print(): callback designed for printing.
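Because jax.debug.callback returns None, it cannot feed host-computed values back into the computation. When a result is needed, jax.pure_callback() is the intended tool. A minimal sketch, assuming the host function is pure and that its output shape and dtype match the declared `jax.ShapeDtypeStruct`; `host_square` and `h` are illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_square(x):
    # Runs on the host with a concrete array; returns a NumPy array
    return np.square(np.asarray(x))


@jax.jit
def h(x):
    # Declare the shape/dtype of the callback's result so JAX can stage it
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_square, out_spec, x)
    return y + 1.0


result = h(jnp.arange(3.0, dtype=jnp.float32))
print(result)
```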