Public API: jax package
Subpackages
- jax.numpy module
- jax.scipy module
- jax.lax module
- jax.random module
- jax.sharding module
- jax.debug module
- jax.dlpack module
- jax.distributed module
- jax.dtypes module
- jax.ffi module
- jax.extend.ffi module (deprecated)
- jax.flatten_util module
- jax.image module
- jax.nn module
- jax.ops module
- jax.profiler module
- jax.stages module
- jax.test_util module
- jax.tree module
- jax.tree_util module
- jax.typing module
- jax.export module
- jax.extend module
- jax.example_libraries module
- jax.experimental module
Configuration
Name | Description |
---|---|
config | |
check_tracer_leaks | Context manager for jax_check_tracer_leaks config option. |
checking_leaks | Context manager for jax_check_tracer_leaks config option. |
debug_nans | Context manager for jax_debug_nans config option. |
debug_infs | Context manager for jax_debug_infs config option. |
default_device | Context manager for jax_default_device config option. |
default_matmul_precision | Context manager for jax_default_matmul_precision config option. |
default_prng_impl | Context manager for jax_default_prng_impl config option. |
enable_checks | Context manager for jax_enable_checks config option. |
enable_custom_prng | Context manager for jax_enable_custom_prng config option (transient). |
enable_custom_vjp_by_custom_transpose | Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient). |
log_compiles | Context manager for jax_log_compiles config option. |
numpy_rank_promotion | Context manager for jax_numpy_rank_promotion config option. |
transfer_guard(new_val) | A context manager to control the transfer guard level for all transfers. |
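Each context manager above scopes one config option to a `with` block. A minimal sketch, assuming `jax_numpy_rank_promotion` is left at its default of `"allow"`: inside `jax.numpy_rank_promotion("raise")`, an addition that relies on implicit rank promotion raises a `ValueError`, while the same expression outside the block succeeds. The helper `raises_under_rank_promotion_raise` is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)       # shape (3,)
y = jnp.ones((2, 3))  # shape (2, 3)

def raises_under_rank_promotion_raise(a, b):
    """Hypothetical helper: True if `a + b` is rejected under rank promotion 'raise'."""
    with jax.numpy_rank_promotion("raise"):
        try:
            a + b  # would implicitly promote a from rank 1 to rank 2
        except ValueError:
            return True
    return False

raised = raises_under_rank_promotion_raise(x, y)

# Outside the context manager the process-wide setting applies again
# (assumed to be the default, "allow"), so the same addition succeeds.
z = x + y
```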
Just-in-time compilation (jit)
Name | Description |
---|---|
jit(fun[, in_shardings, out_shardings, ...]) | Sets up fun for just-in-time compilation with XLA. |
disable_jit([disable]) | Context manager that disables jit() behavior under its dynamic context. |
ensure_compile_time_eval() | Context manager to ensure evaluation at trace/compile time (or error). |
make_jaxpr([axis_env, return_shape, ...]) | Creates a function that produces its jaxpr given example args. |
eval_shape(fun, *args, **kwargs) | Compute the shape/dtype of fun without any FLOPs. |
ShapeDtypeStruct(shape, dtype, *[, ...]) | A container for the shape, dtype, and other static attributes of an array. |
device_put(x[, device, src, donate, may_alias]) | Transfers x to device. |
device_get(x) | Transfer x to host. |
default_backend() | Returns the platform name of the default XLA backend. |
named_call(fun, *[, name]) | Adds a user specified name to a function when staging out JAX computations. |
named_scope(name) | A context manager that adds a user specified name to the JAX name stack. |
block_until_ready(x) | Tries to call a block_until_ready method on pytree leaves. |
make_mesh(axis_shapes, axis_names, *[, ...]) | Creates an efficient mesh with the shape and axis names specified. |
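A short sketch of how several of these entries fit together: `jit` compiles a function, `block_until_ready` waits for the asynchronously dispatched result, `eval_shape` with a `ShapeDtypeStruct` computes output shape and dtype without running the computation, and `make_jaxpr` exposes the traced program. The `selu` function and its constants are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    # Illustrative scaled exponential linear unit written with jax.numpy ops.
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)  # compiled with XLA on first call, then cached
x = jnp.arange(5.0)

# Dispatch is asynchronous; block_until_ready waits for the result (useful when timing).
y = jax.block_until_ready(selu_jit(x))

# eval_shape traces abstractly: no FLOPs and no 1000x1000 input array is allocated.
spec = jax.ShapeDtypeStruct((1000, 1000), jnp.float32)
out_spec = jax.eval_shape(selu, spec)

# make_jaxpr shows the intermediate representation produced by tracing.
jaxpr = jax.make_jaxpr(selu)(x)
print(jaxpr)
```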
Automatic differentiation
Name | Description |
---|---|
grad(fun[, argnums, has_aux, holomorphic, ...]) | Creates a function that evaluates the gradient of fun. |
value_and_grad(fun[, argnums, has_aux, ...]) | Create a function that evaluates both fun and the gradient of fun. |
jacobian(fun[, argnums, has_aux, ...]) | Alias of jax.jacrev(). |
jacfwd(fun[, argnums, has_aux, holomorphic]) | Jacobian of fun evaluated column-by-column using forward-mode AD. |
jacrev(fun[, argnums, has_aux, holomorphic, ...]) | Jacobian of fun evaluated row-by-row using reverse-mode AD. |
hessian(fun[, argnums, has_aux, holomorphic]) | Hessian of fun as a dense array. |
jvp(fun, primals, tangents[, has_aux]) | Computes a (forward-mode) Jacobian-vector product of fun. |
linearize(fun, *primals[, has_aux]) | Produces a linear approximation to fun using jvp() and partial eval. |
linear_transpose(fun, *primals[, reduce_axes]) | Transpose a function that is promised to be linear. |
vjp(fun, *primals[, has_aux]) | Compute a (reverse-mode) vector-Jacobian product of fun. |
custom_gradient(fun) | Convenience function for defining custom VJP rules (aka custom gradients). |
closure_convert(fun, *example_args) | Closure conversion utility, for use with higher-order custom derivatives. |
checkpoint(fun, *[, prevent_cse, policy, ...]) | Make fun recompute internal linearization points when differentiated. |
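A minimal sketch of the main differentiation entry points on the illustrative function f(x) = Σ xᵢ³, whose gradient is 3x² and whose Hessian is diag(6x). `jvp` pushes a tangent forward; `vjp` pulls a cotangent back and, for a scalar output with cotangent 1, reproduces the gradient.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 3)  # scalar-valued: f(x) = sum_i x_i^3

x = jnp.array([1.0, 2.0, 3.0])

g = jax.grad(f)(x)                    # 3 x^2
value, g2 = jax.value_and_grad(f)(x)  # f(x) and its gradient together
J = jax.jacfwd(lambda v: v ** 2)(x)   # Jacobian of an elementwise map: diag(2 x)
H = jax.hessian(f)(x)                 # diag(6 x)

# Forward mode: directional derivative of f along t.
t = jnp.ones_like(x)
primal_out, tangent_out = jax.jvp(f, (x,), (t,))

# Reverse mode: pull a unit cotangent for the scalar output back to x.
y, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.ones_like(y))
```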
Customization
custom_jvp
Name | Description |
---|---|
custom_jvp(fun[, nondiff_argnums]) | Set up a JAX-transformable function for a custom JVP rule definition. |
custom_jvp.defjvp(jvp[, symbolic_zeros]) | Define a custom JVP rule for the function represented by this instance. |
custom_jvp.defjvps(*jvps) | Convenience wrapper for defining JVPs for each argument separately. |
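A sketch of `custom_jvp` with `defjvp`, using the illustrative function log(1 + eˣ). Differentiating the naive formula at x = 100 yields nan in float32 because eˣ overflows to inf; the hand-written rule expresses the derivative as 1 − 1/(1 + eˣ), which stays finite.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x); finite even when e^x overflows to inf.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

grad_at_0 = jax.grad(log1pexp)(0.0)
grad_at_100 = jax.grad(log1pexp)(100.0)  # the naive derivative would be nan here
```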
custom_vjp
Name | Description |
---|---|
custom_vjp(fun[, nondiff_argnums]) | Set up a JAX-transformable function for a custom VJP rule definition. |
custom_vjp.defvjp(fwd, bwd[, ...]) | Define a custom VJP rule for the function represented by this instance. |
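A sketch of `custom_vjp` with `defvjp`: the forward function returns the primal output plus residuals, and the backward function maps residuals and the output cotangent to one cotangent per input (`None` marks a zero cotangent). The illustrative `clip_gradient` is the identity on the forward pass but clips the gradient flowing through it.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # primal output plus residuals for the backward pass

def clip_gradient_bwd(res, g):
    lo, hi = res
    # One cotangent per primal input; None means zero cotangent for lo and hi.
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# The cotangent reaching clip_gradient is 4.0; the custom rule clips it to 0.5.
clipped = jax.grad(lambda x: 4.0 * clip_gradient(-0.5, 0.5, x))(1.0)
```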
custom_batching
jax.Array (jax.Array)
Array properties and methods
Name | Description |
---|---|
Array.addressable_shards | List of addressable shards. |
Array.all([axis, out, keepdims, where]) | Test whether all array elements along a given axis evaluate to True. |
Array.any([axis, out, keepdims, where]) | Test whether any array elements along a given axis evaluate to True. |
Array.argmax([axis, out, keepdims]) | Return the index of the maximum value. |
Array.argmin([axis, out, keepdims]) | Return the index of the minimum value. |
Array.argpartition(kth[, axis]) | Return the indices that partially sort the array. |
Array.argsort([axis, kind, order, stable, ...]) | Return the indices that sort the array. |
Array.astype(dtype[, copy, device]) | Copy the array and cast to a specified dtype. |
Array.at | Helper property for index update functionality. |
Array.choose(choices[, out, mode]) | Construct an array choosing from elements of multiple arrays. |
Array.clip([min, max]) | Return an array whose values are limited to a specified range. |
Array.compress(condition[, axis, out, size, ...]) | Return selected slices of this array along given axis. |
Array.committed | Whether the array is committed or not. |
Array.conj() | Return the complex conjugate of the array. |
Array.conjugate() | Return the complex conjugate of the array. |
Array.copy() | Return a copy of the array. |
Array.copy_to_host_async() | Copies an Array to the host asynchronously. |
Array.cumprod([axis, dtype, out]) | Return the cumulative product of the array. |
Array.cumsum([axis, dtype, out]) | Return the cumulative sum of the array. |
Array.device | Array API-compatible device attribute. |
Array.diagonal([offset, axis1, axis2]) | Return the specified diagonal from the array. |
Array.dot(b, *[, precision, ...]) | Compute the dot product of two arrays. |
Array.dtype | The data type (numpy.dtype) of the array. |
Array.flat | Use flatten() instead. |
Array.flatten([order]) | Flatten array into a 1-dimensional shape. |
Array.global_shards | List of global shards. |
Array.imag | Return the imaginary part of the array. |
Array.is_fully_addressable | Is this Array fully addressable? |
Array.is_fully_replicated | Is this Array fully replicated? |
Array.item(*args) | Copy an element of an array to a standard Python scalar and return it. |
Array.itemsize | Length of one array element in bytes. |
Array.max([axis, out, keepdims, initial, where]) | Return the maximum of array elements along a given axis. |
Array.mean([axis, dtype, out, keepdims, where]) | Return the mean of array elements along a given axis. |
Array.min([axis, out, keepdims, initial, where]) | Return the minimum of array elements along a given axis. |
Array.nbytes | Total bytes consumed by the elements of the array. |
Array.ndim | The number of dimensions in the array. |
Array.nonzero(*[, fill_value, size]) | Return indices of nonzero elements of an array. |
Array.prod([axis, dtype, out, keepdims, ...]) | Return product of the array elements over a given axis. |
Array.ptp([axis, out, keepdims]) | Return the peak-to-peak range along a given axis. |
Array.ravel([order]) | Flatten array into a 1-dimensional shape. |
Array.real | Return the real part of the array. |
Array.repeat(repeats[, axis, ...]) | Construct an array from repeated elements. |
Array.reshape(*args[, order]) | Returns an array containing the same data with a new shape. |
Array.round([decimals, out]) | Round array elements to a given decimal. |
Array.searchsorted(v[, side, sorter, method]) | Perform a binary search within a sorted array. |
Array.shape | The shape of the array. |
Array.sharding | The sharding for the array. |
Array.size | The total number of elements in the array. |
Array.sort([axis, kind, order, stable, ...]) | Return a sorted copy of an array. |
Array.squeeze([axis]) | Remove one or more length-1 axes from array. |
Array.std([axis, dtype, out, ddof, ...]) | Compute the standard deviation along a given axis. |
Array.sum([axis, dtype, out, keepdims, ...]) | Sum of the elements of the array over a given axis. |
Array.swapaxes(axis1, axis2) | Swap two axes of an array. |
Array.take(indices[, axis, out, mode, ...]) | Take elements from an array. |
Array.to_device(device, *[, stream]) | Return a copy of the array on the specified device. |
Array.trace([offset, axis1, axis2, dtype, out]) | Return the sum along the diagonal. |
Array.transpose(*args) | Returns a copy of the array with axes transposed. |
Array.var([axis, dtype, out, ddof, ...]) | Compute the variance along a given axis. |
Array.view([dtype, type]) | Return a bitwise copy of the array, viewed as a new dtype. |
Array.T | Compute the all-axis array transpose. |
Array.mT | Compute the (batched) matrix transpose. |
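A brief sketch of common `jax.Array` properties and methods. Because arrays are immutable, the `Array.at` helper returns an updated copy instead of modifying in place. The dtype is float32 unless `jax_enable_x64` is set.

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Functional index updates: x itself is left unchanged.
y = x.at[0, 1].set(10.0)
z = x.at[:, 0].add(1.0)

col_means = x.mean(axis=0)  # NumPy-style reduction method
xt = x.T                    # all-axis transpose, shape (3, 2)

# Attributes describing layout and placement.
print(x.shape, x.ndim, x.size, x.dtype, x.nbytes)
print(x.sharding, x.is_fully_addressable)
```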
Vectorization (vmap)
Name | Description |
---|---|
vmap(fun[, in_axes, out_axes, axis_name, ...]) | Vectorizing map. |
numpy.vectorize(pyfunc, *[, excluded, signature]) | Define a vectorized function with broadcasting. |
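A sketch contrasting the two vectorization styles on an illustrative `dot` written for single vectors: `vmap` states which axis of each argument to map over (`None` broadcasts), while `jax.numpy.vectorize` describes the core dimensions with a gufunc-style signature and broadcasts the rest.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    return jnp.dot(a, b)  # written for a single pair of vectors

A = jnp.arange(12.0).reshape(4, 3)  # 4 row vectors of length 3
v = jnp.array([1.0, 2.0, 3.0])

# Map over axis 0 of A; in_axes=None passes the same v to every call.
row_dots = jax.vmap(dot, in_axes=(0, None))(A, v)

# Same result via core-dimension signature and broadcasting.
row_dots_vec = jnp.vectorize(dot, signature="(n),(n)->()")(A, v)
```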
Parallelization (pmap)
Name | Description |
---|---|
pmap(fun[, axis_name, in_axes, out_axes, ...]) | Parallel map with support for collective operations. |
devices([backend]) | Returns a list of all devices for a given backend. |
local_devices([process_index, backend, host_id]) | Like jax.devices(), but only returns devices local to a given process. |
process_index([backend]) | Returns the integer process index of this process. |
device_count([backend]) | Returns the total number of devices. |
local_device_count([backend]) | Returns the number of devices addressable by this process. |
process_count([backend]) | Returns the number of JAX processes associated with the backend. |
process_indices([backend]) | Returns the list of all JAX process indices associated with the backend. |
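A sketch combining the device queries with `pmap` and a collective, plus explicit placement via `device_put` and `device_get` from the jit section. It assumes nothing about the hardware: on a single-device CPU install `n` is 1 and the collective is trivial. The `center` function is an illustrative name.

```python
import jax
import jax.numpy as jnp
import numpy as np

print(jax.default_backend(), jax.devices())
n = jax.local_device_count()  # devices addressable by this process

# One row per local device; pmap runs `center` once per device.
xs = jnp.arange(n * 3.0).reshape(n, 3)

def center(x):
    # pmean averages x across all devices taking part in axis "i".
    return x - jax.lax.pmean(x, axis_name="i")

centered = jax.pmap(center, axis_name="i")(xs)

# Explicit placement on a device, then transfer back to the host.
x_dev = jax.device_put(np.ones(3, dtype=np.float32), jax.devices()[0])
x_host = jax.device_get(x_dev)  # a NumPy array on the host
```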
Callbacks
Name | Description |
---|---|
pure_callback(callback, result_shape_dtypes, ...) | Calls a pure Python callback. |
experimental.io_callback(callback, ...[, ...]) | Calls an impure Python callback. |
debug.callback(callback, *args[, ordered]) | Calls a stageable Python callback. |
debug.print(fmt, *args[, ordered]) | Prints values and works in staged out JAX functions. |
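A sketch of two callbacks inside a jitted function: `debug.print` prints runtime values rather than tracers, and `pure_callback` runs host-side NumPy code, which requires declaring the result shape and dtype up front with a `ShapeDtypeStruct`. The host function `host_sinc` is an illustrative name.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sinc(x):
    # Runs on the host; np.asarray makes the NumPy computation explicit.
    x = np.asarray(x)
    return np.sinc(x).astype(x.dtype)  # dtype must match the declared result

@jax.jit
def f(x):
    jax.debug.print("x = {}", x)  # prints concrete values even under jit
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sinc, out_spec, x)

y = f(jnp.array([0.0, 0.5, 1.0]))
```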
Miscellaneous
Name | Description |
---|---|
Device | A descriptor of an available device. |
print_environment_info([return_string]) | Returns a string containing local environment & JAX installation information. |
live_arrays([platform]) | Return all live arrays in the backend for platform. |
clear_caches() | Clear all compilation and staging caches. |
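A short sketch of the miscellaneous utilities: capturing environment information as a string, inspecting a `Device`, listing live arrays on the default backend, and clearing compilation caches.

```python
import jax
import jax.numpy as jnp

info = jax.print_environment_info(return_string=True)  # returned instead of printed
print(info)

dev = jax.devices()[0]  # a jax.Device
print(dev.platform, dev.id)

x = jnp.ones(1000)
live = jax.live_arrays()  # arrays currently alive on the default backend
print(len(live))

jax.clear_caches()  # drop compilation and staging caches
```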