jax.disable_jit — JAX documentation
jax.disable_jit(disable=True)
Context manager that disables jit() behavior under its dynamic context.
For debugging, it is useful to have a mechanism that disables jit() everywhere in a dynamic context. Note that this not only disables explicit uses of jit() by the user, but also removes any implicit JIT compilation used by the JAX library: this includes implicit JIT computation of body and cond functions passed to higher-level primitives like scan() and while_loop(), JIT used in implementations of jax.numpy functions, and any other case where jit() is used within an API’s implementation. Note, however, that even under disable_jit, individual primitive operations will still be compiled by XLA as in normal eager op-by-op execution.
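To make the point about implicit JIT concrete, here is a minimal sketch using lax.scan. It assumes the behavior of current JAX releases, in which scan, when jit is disabled, calls its body eagerly once per step instead of tracing it into a single computation. The names `body`, `calls`, `xs`, and `init` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

calls = []  # illustrative: type name of `carry` on each Python-level call of `body`

def body(carry, x):
    calls.append(type(carry).__name__)
    # A side effect: shows a tracer when scan traces `body`,
    # and a concrete array when `body` runs eagerly.
    print("carry =", carry)
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(3, dtype=jnp.int32)
init = jnp.int32(0)

# Default behavior: scan traces `body` into one computation,
# so the print above shows an abstract tracer rather than a value.
lax.scan(body, init, xs)

# With jit disabled, `body` is called eagerly, once per element of `xs`,
# so each print shows a concrete carry.
calls.clear()
with jax.disable_jit():
    final, ys = lax.scan(body, init, xs)

print(len(calls))  # 3
print(final, ys)   # 3 [0 1 3]
```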
Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not one concrete array with specific values. You might notice these abstract values if you use a benign side-effecting operation, like a print, in a jitted function:
>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<int32[3]>with<DynamicJaxprTrace...>
[5 7 9]
Here y has been abstracted by jit() to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. The value of y is also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit() context manager:
>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
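Building on the ShapedArray discussion above, the following rough sketch shows what an abstract value does and does not know. Under jit, the shape of y is an ordinary Python tuple, but a Python `if` on its values cannot be evaluated. Under disable_jit, the same function runs on concrete arrays. The function name `scale_and_branch` and the `shapes` list are hypothetical, and the sketch assumes the failure under jit is raised as `jax.errors.ConcretizationTypeError` or a subclass of it.

```python
import jax
import jax.numpy as jnp

shapes = []  # hypothetical: records y.shape each time the Python body runs

@jax.jit
def scale_and_branch(x):
    y = x * 2
    # The shape is part of the abstract value, so it is a plain tuple
    # even while `y` is being traced.
    shapes.append(y.shape)
    # Branching on the *values* of `y` needs a concrete array.
    if y.sum() > 10:
        return y - 10
    return y

x = jnp.array([1, 2, 3], dtype=jnp.int32)

try:
    scale_and_branch(x)  # under jit, `y.sum() > 10` is abstract, so `if` fails
    jit_error = None
except jax.errors.ConcretizationTypeError as err:
    jit_error = type(err).__name__

with jax.disable_jit():
    result = scale_and_branch(x)  # eager: y == [2 4 6], and 12 > 10

print(jit_error)
print(result)   # [-8 -6 -4]
print(shapes)   # [(3,), (3,)]
```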
Parameters:
disable (bool)
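The `disable` parameter lets one code path switch jit off only when needed, for example from a debug flag. The sketch below uses hypothetical names `double`, `run`, and `observed`. It detects whether a value was abstract or concrete by relying on the fact that converting a tracer with `np.asarray` raises `jax.errors.TracerArrayConversionError`.

```python
import jax
import jax.numpy as jnp
import numpy as np

observed = []  # hypothetical: per Python-level call of `double`, was x concrete?

@jax.jit
def double(x):
    try:
        np.asarray(x)  # fails for an abstract tracer
        observed.append("concrete")
    except jax.errors.TracerArrayConversionError:
        observed.append("abstract")
    return x * 2

def run(x, debug):
    # `disable` decides whether jit is switched off for this block.
    with jax.disable_jit(disable=debug):
        return double(x)

x = jnp.arange(3.0)
run(x, debug=False)  # jit stays on: the first call traces `double`
run(x, debug=True)   # jit switched off: `double` runs eagerly
print(observed)      # ['abstract', 'concrete']
```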