API Reference
Subpackages
- jax.numpy module
- jax.scipy module
- jax.lax module
- jax.random module
- jax.sharding module
- jax.ad_checkpoint module
- jax.debug module
- jax.dlpack module
- jax.distributed module
- jax.dtypes module
- jax.ffi module
- jax.flatten_util module
- jax.image module
- jax.nn module
- jax.ops module
- jax.profiler module
- jax.ref module
- jax.stages module
- jax.test_util module
- jax.tree module
- jax.tree_util module
- jax.typing module
- jax.export module
- jax.extend module
- jax.example_libraries module
- jax.experimental module
Configuration
| API | Description |
|---|---|
| config | |
| check_tracer_leaks | Context manager for jax_check_tracer_leaks config option. |
| checking_leaks | Context manager for jax_check_tracer_leaks config option. |
| debug_nans | Context manager for jax_debug_nans config option. |
| debug_infs | Context manager for jax_debug_infs config option. |
| default_device | Context manager for jax_default_device config option. |
| default_matmul_precision | Context manager for jax_default_matmul_precision config option. |
| default_prng_impl | Context manager for jax_default_prng_impl config option. |
| enable_checks | Context manager for jax_enable_checks config option. |
| enable_custom_prng | Context manager for jax_enable_custom_prng config option (transient). |
| enable_custom_vjp_by_custom_transpose | Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient). |
| enable_x64 | Context manager for jax_enable_x64 config option. |
| log_compiles | Context manager for jax_log_compiles config option. |
| no_tracing | Context manager for jax_no_tracing config option. |
| numpy_rank_promotion | Context manager for jax_numpy_rank_promotion config option. |
| transfer_guard(new_val) | A context manager to control the transfer guard level for all transfers. |
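Each context manager in this table scopes one option to a `with` block and restores the previous value on exit. The following is a minimal sketch. It assumes a JAX version that exports `jax.numpy_rank_promotion` at the top level, as listed here. Inside the block, implicit rank promotion raises an error. Outside it, the same expression broadcasts as usual.

```python
import jax
import jax.numpy as jnp

a = jnp.ones(3)        # shape (3,)
b = jnp.ones((2, 3))   # shape (2, 3)

# Scoped override: implicit rank promotion is an error inside the block.
try:
    with jax.numpy_rank_promotion("raise"):
        _ = a + b
    raised = False
except ValueError:
    raised = True

# Outside the block the default ("allow") is back in effect,
# so the rank-1 operand broadcasts against the rank-2 one.
allowed = a + b
print(raised, allowed.shape)
```

To set the same option for the whole process rather than for one block, use `jax.config.update("jax_numpy_rank_promotion", "raise")`.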
Just-in-time compilation (jit)
| API | Description |
|---|---|
| jit() | Sets up fun for just-in-time compilation with XLA. |
| disable_jit([disable]) | Context manager that disables jit() behavior under its dynamic context. |
| ensure_compile_time_eval() | Context manager to ensure evaluation at trace/compile time (or error). |
| make_jaxpr([axis_env, return_shape]) | Create a function that returns the jaxpr of fun given example args. |
| eval_shape(fun, *args, **kwargs) | Compute the shape/dtype of fun without any FLOPs. |
| ShapeDtypeStruct(shape, dtype, *[, ...]) | A container for the shape, dtype, and other static attributes of an array. |
| device_put(x[, device, src, donate, may_alias]) | Transfers x to device. |
| device_get(x) | Transfer x to host. |
| default_backend() | Returns the platform name of the default XLA backend. |
| named_call(fun, *[, name]) | Adds a user specified name to a function when staging out JAX computations. |
| named_scope(name) | A context manager that adds a user specified name to the JAX name stack. |
| block_until_ready(x) | Tries to call a block_until_ready method on pytree leaves. |
| copy_to_host_async(x) | Tries to call a copy_to_host_async method on pytree leaves. |
| make_mesh(axis_shapes, axis_names[, ...]) | Creates an efficient mesh with the shape and axis names specified. |
| set_mesh(mesh) | Sets a concrete mesh in a thread-local context. |
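The sketch below ties several of these entries together on a toy affine function. `affine` is a hypothetical example, not part of JAX. The roles of each call:

- `jit` compiles the function.
- `block_until_ready` waits for the asynchronous result.
- `make_jaxpr` shows the traced program.
- `eval_shape` with `ShapeDtypeStruct` inputs infers the output shape without computing anything.
- `device_get` copies the result to host memory.

```python
import jax
import jax.numpy as jnp
import numpy as np

def affine(w, x, b):
    """Hypothetical example function: computes w @ x + b."""
    return jnp.dot(w, x) + b

fast_affine = jax.jit(affine)  # traced and compiled on the first call

w = jnp.ones((3, 4))
x = jnp.arange(4.0)   # [0., 1., 2., 3.]
b = jnp.zeros(3)

# Dispatch is asynchronous; block_until_ready waits for the computation.
y = jax.block_until_ready(fast_affine(w, x, b))

# Inspect the traced program without running it.
jaxpr = jax.make_jaxpr(affine)(w, x, b)

# Shape/dtype inference only: no FLOPs and no device buffers.
spec = jax.eval_shape(
    affine,
    jax.ShapeDtypeStruct((3, 4), jnp.float32),
    jax.ShapeDtypeStruct((4,), jnp.float32),
    jax.ShapeDtypeStruct((3,), jnp.float32),
)

host_y = jax.device_get(y)  # a NumPy array in host memory
print(host_y, spec.shape, spec.dtype)
```

`eval_shape` only traces the function, so it can check the output shapes of a large computation before any device memory is allocated.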
Automatic differentiation
| API | Description |
|---|---|
| grad(fun[, argnums, has_aux, holomorphic, ...]) | Creates a function that evaluates the gradient of fun. |
| value_and_grad(fun[, argnums, has_aux, ...]) | Create a function that evaluates both fun and the gradient of fun. |
| jacobian(fun[, argnums, has_aux, ...]) | Alias of jax.jacrev(). |
| jacfwd(fun[, argnums, has_aux, holomorphic]) | Jacobian of fun evaluated column-by-column using forward-mode AD. |
| jacrev(fun[, argnums, has_aux, holomorphic, ...]) | Jacobian of fun evaluated row-by-row using reverse-mode AD. |
| hessian(fun[, argnums, has_aux, holomorphic]) | Hessian of fun as a dense array. |
| jvp(fun, primals, tangents[, has_aux]) | Computes a (forward-mode) Jacobian-vector product of fun. |
| linearize() | Produces a linear approximation to fun using jvp() and partial eval. |
| linear_transpose(fun, *primals[, reduce_axes]) | Transpose a function that is promised to be linear. |
| vjp() | Compute a (reverse-mode) vector-Jacobian product of fun. |
| custom_gradient(fun) | Convenience function for defining custom VJP rules (aka custom gradients). |
| closure_convert(fun, *example_args) | Closure conversion utility, for use with higher-order custom derivatives. |
| checkpoint(fun, *[, prevent_cse, policy, ...]) | Make fun recompute internal linearization points when differentiated. |
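These transformations compose on ordinary Python functions. The following is a minimal sketch on f(x) = sum(x**3), whose gradient 3x² and Hessian diag(6x) are easy to check by hand. `jvp` computes a directional derivative in forward mode. `vjp` returns a pullback that reproduces `grad` when seeded with a cotangent of ones.

```python
import jax
import jax.numpy as jnp

def f(x):
    """f(x) = sum(x**3): gradient 3 x**2, Hessian diag(6 x)."""
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])

g = jax.grad(f)(x)                     # reverse-mode gradient
value, g2 = jax.value_and_grad(f)(x)   # value and gradient together

# Forward mode: directional derivative of f along v = [1, 1].
_, f_dot = jax.jvp(f, (x,), (jnp.ones_like(x),))

# Reverse mode with an explicit pullback; a ones cotangent recovers the gradient.
out, f_vjp = jax.vjp(f, x)
(g3,) = f_vjp(jnp.ones_like(out))

H = jax.hessian(f)(x)

# checkpoint changes what is stored vs. recomputed, not the result.
g4 = jax.grad(jax.checkpoint(f))(x)

print(g, value, f_dot, H)
```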
Vectorization
| API | Description |
|---|---|
| vmap(fun[, in_axes, out_axes, axis_name, ...]) | Vectorizing map. Creates a function that maps fun over argument axes. |
| numpy.vectorize(pyfunc, *[, excluded, signature]) | Define a vectorized function with broadcasting. |
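`vmap` adds a batch dimension to a function written for single examples. The following is a minimal sketch that maps a vector dot product over the rows of a matrix. It uses `in_axes=(0, None)`, so the second argument is shared across calls rather than batched. `jnp.vectorize` expresses the same mapping with a NumPy gufunc-style signature.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    """Dot product of two 1-D vectors (written for a single example)."""
    return jnp.dot(a, b)

A = jnp.arange(6.0).reshape(3, 2)   # three row vectors: [0,1], [2,3], [4,5]
v = jnp.array([1.0, 1.0])

# Map over axis 0 of A; pass v unchanged to every call.
rows_dot_v = jax.vmap(dot, in_axes=(0, None))(A, v)

# The same mapping via a core signature: (n),(n) -> scalar, broadcasting the rest.
rows_dot_v2 = jnp.vectorize(dot, signature="(n),(n)->()")(A, v)

print(rows_dot_v, rows_dot_v2)
```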
Parallelization
| API | Description |
|---|---|
| shard_map([f, in_specs, mesh, axis_names, ...]) | Map a function over shards of data using a mesh of devices. |
| smap([f, in_axes]) | Single axis shard_map that maps a function f one axis at a time. |
| pmap(fun[, axis_name, in_axes, out_axes, ...]) | Legacy parallel map over devices; new code should prefer shard_map(). |
| devices([backend]) | Returns a list of all devices for a given backend. |
| local_devices([process_index, backend, host_id]) | Like jax.devices(), but only returns devices local to a given process. |
| process_index([backend]) | Returns the integer process index of this process. |
| device_count([backend]) | Returns the total number of devices. |
| local_device_count([backend]) | Returns the number of devices addressable by this process. |
| process_count([backend]) | Returns the number of JAX processes associated with the backend. |
| process_indices([backend]) | Returns the list of all JAX process indices associated with the backend. |
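In a multi-process job, the global and per-process queries differ. `device_count()` counts every device across all processes, while `local_device_count()` counts only the devices addressable by the current process. The following is a minimal sketch that prints both views. On a single-host, single-process setup the two counts coincide.

```python
import jax

# Global view: devices across all processes for the default backend.
all_devices = jax.devices()
# Local view: devices this process can address directly.
my_devices = jax.local_devices()

print("backend:", jax.default_backend())
print("process", jax.process_index(), "of", jax.process_count())
print("devices:", jax.device_count(), "global,", jax.local_device_count(), "local")
```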
Customization
custom_jvp
| API | Description |
|---|---|
| custom_jvp(fun[, nondiff_argnums, ...]) | Set up a JAX-transformable function for a custom JVP rule definition. |
| custom_jvp.defjvp(jvp[, symbolic_zeros]) | Define a custom JVP rule for the function represented by this instance. |
| custom_jvp.defjvps(*jvps) | Convenience wrapper for defining JVPs for each argument separately. |
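A common use of `custom_jvp` is to supply a numerically stable derivative. The following is a minimal sketch modeled on the well-known `log1pexp` pattern. The automatically derived derivative of `log(1 + exp(x))` is `exp(x) / (1 + exp(x))`, which becomes `inf / inf = nan` in float32 once `exp(x)` overflows. The hand-written rule uses the equivalent form `1 - 1 / (1 + exp(x))`, which stays finite.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = e^x / (1 + e^x), rewritten to stay finite when e^x overflows.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

print(jax.grad(log1pexp)(100.0), jax.grad(log1pexp)(0.0))
```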
custom_vjp
| API | Description |
|---|---|
| custom_vjp(fun[, nondiff_argnums, ...]) | Set up a JAX-transformable function for a custom VJP rule definition. |
| custom_vjp.defvjp(fwd, bwd[, ...]) | Define a custom VJP rule for the function represented by this instance. |
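`custom_vjp` splits a rule into two parts: a forward pass that saves residuals, and a backward pass that consumes them. The following is a minimal sketch of a common gradient-clipping pattern. `clip_gradient` is a hypothetical name. The function is the identity on the forward pass, but its cotangent is clipped to `[lo, hi]` on the backward pass.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # save the bounds as residuals

def clip_gradient_bwd(res, g):
    lo, hi = res
    # One cotangent per primal input; None marks a zero cotangent for lo and hi.
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def loss(x):
    return 10.0 * clip_gradient(-1.0, 1.0, x)

# Without clipping the gradient would be 10.0; the backward rule caps it at 1.0.
print(loss(2.0), jax.grad(loss)(2.0))
```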
custom_batching
jax.Array
Array properties and methods
| API | Description |
|---|---|
| Array.addressable_shards | List of addressable shards. |
| Array.all([axis, out, keepdims, where]) | Test whether all array elements along a given axis evaluate to True. |
| Array.any([axis, out, keepdims, where]) | Test whether any array elements along a given axis evaluate to True. |
| Array.argmax([axis, out, keepdims]) | Return the index of the maximum value. |
| Array.argmin([axis, out, keepdims]) | Return the index of the minimum value. |
| Array.argpartition(kth[, axis]) | Return the indices that partially sort the array. |
| Array.argsort([axis, kind, order, stable, ...]) | Return the indices that sort the array. |
| Array.astype(dtype[, copy, device]) | Copy the array and cast to a specified dtype. |
| Array.at | Helper property for index update functionality. |
| Array.choose(choices[, out, mode]) | Construct an array choosing from elements of multiple arrays. |
| Array.clip([min, max]) | Return an array whose values are limited to a specified range. |
| Array.compress(condition[, axis, out, size, ...]) | Return selected slices of this array along given axis. |
| Array.committed | Whether the array is committed or not. |
| Array.conj() | Return the complex conjugate of the array. |
| Array.conjugate() | Return the complex conjugate of the array. |
| Array.copy() | Return a copy of the array. |
| Array.copy_to_host_async() | Copies an Array to the host asynchronously. |
| Array.cumprod([axis, dtype, out]) | Return the cumulative product of the array. |
| Array.cumsum([axis, dtype, out]) | Return the cumulative sum of the array. |
| Array.device | Array API-compatible device attribute. |
| Array.diagonal([offset, axis1, axis2]) | Return the specified diagonal from the array. |
| Array.dot(b, *[, precision, ...]) | Compute the dot product of two arrays. |
| Array.dtype | The data type (numpy.dtype) of the array. |
| Array.flat | Use flatten() instead. |
| Array.flatten([order, out_sharding]) | Flatten array into a 1-dimensional shape. |
| Array.global_shards | List of global shards. |
| Array.imag | Return the imaginary part of the array. |
| Array.is_fully_addressable | Is this Array fully addressable? |
| Array.is_fully_replicated | Is this Array fully replicated? |
| Array.item(*args) | Copy an element of an array to a standard Python scalar and return it. |
| Array.itemsize | Length of one array element in bytes. |
| Array.max([axis, out, keepdims, initial, where]) | Return the maximum of array elements along a given axis. |
| Array.mean([axis, dtype, out, keepdims, where]) | Return the mean of array elements along a given axis. |
| Array.min([axis, out, keepdims, initial, where]) | Return the minimum of array elements along a given axis. |
| Array.nbytes | Total bytes consumed by the elements of the array. |
| Array.ndim | The number of dimensions in the array. |
| Array.nonzero(*[, fill_value, size]) | Return indices of nonzero elements of an array. |
| Array.prod([axis, dtype, out, keepdims, ...]) | Return product of the array elements over a given axis. |
| Array.ptp([axis, out, keepdims]) | Return the peak-to-peak range along a given axis. |
| Array.ravel([order, out_sharding]) | Flatten array into a 1-dimensional shape. |
| Array.real | Return the real part of the array. |
| Array.repeat(repeats[, axis, ...]) | Construct an array from repeated elements. |
| Array.reshape(*args[, order, out_sharding]) | Returns an array containing the same data with a new shape. |
| Array.round([decimals, out]) | Round array elements to a given decimal. |
| Array.searchsorted(v[, side, sorter, method]) | Perform a binary search within a sorted array. |
| Array.shape | The shape of the array. |
| Array.sharding | The sharding for the array. |
| Array.size | The total number of elements in the array. |
| Array.sort([axis, kind, order, stable, ...]) | Return a sorted copy of an array. |
| Array.squeeze([axis]) | Remove one or more length-1 axes from array. |
| Array.std([axis, dtype, out, ddof, ...]) | Compute the standard deviation along a given axis. |
| Array.sum([axis, dtype, out, keepdims, ...]) | Sum of the elements of the array over a given axis. |
| Array.swapaxes(axis1, axis2) | Swap two axes of an array. |
| Array.take(indices[, axis, out, mode, ...]) | Take elements from an array. |
| Array.to_device(device, *[, stream]) | Return a copy of the array on the specified device. |
| Array.trace([offset, axis1, axis2, dtype, out]) | Return the sum along the diagonal. |
| Array.transpose(*args) | Returns a copy of the array with axes transposed. |
| Array.var([axis, dtype, out, ddof, ...]) | Compute the variance along a given axis. |
| Array.view([dtype, type]) | Return a bitwise copy of the array, viewed as a new dtype. |
| Array.T | Compute the all-axis array transpose. |
| Array.mT | Compute the (batched) matrix transpose. |
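`jax.Array` follows NumPy's array API but is immutable, so indexed assignment goes through the `.at` helper, which returns a new array. The following is a minimal sketch. It assumes the default 32-bit mode, so `nbytes` counts 4 bytes per element.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)   # [[0., 1., 2.], [3., 4., 5.]]

# .at[...] returns an updated copy; x itself is unchanged.
y = x.at[0, 1].set(10.0)
z = x.at[:, 0].add(1.0)

# NumPy-style metadata and methods.
print(x.shape, x.ndim, x.size, x.dtype, x.nbytes)
print(x.sum(axis=1), x.argmax(), x.T.shape)

# Placement information.
print(x.sharding, x.is_fully_addressable)
```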
Callbacks
| API | Description |
|---|---|
| pure_callback(callback, result_shape_dtypes, ...) | Calls a pure Python callback. |
| experimental.io_callback(callback, ...[, ...]) | Calls an impure Python callback. |
| debug.callback(callback, *args[, ordered, ...]) | Calls a stageable Python callback. |
| debug.print(fmt, *args[, ordered, ...]) | Prints values and works in staged out JAX functions. |
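`pure_callback` lets a jitted function call host-side NumPy code. The callback must have no side effects, and its output shape and dtype must be declared up front with `ShapeDtypeStruct`. The following is a minimal sketch that uses `np.sinc` as a stand-in for host-only code; `host_sinc` and `scaled_sinc` are hypothetical names. `debug.print` reports the intermediate value at run time, whereas a plain `print()` under `jit` would only show a tracer.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sinc(x):
    # Runs on the host with NumPy; must return the declared shape and dtype.
    x = np.asarray(x)
    return np.sinc(x).astype(x.dtype)

@jax.jit
def scaled_sinc(x):
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_sinc, result_spec, x)
    jax.debug.print("sinc(x) = {y}", y=y)  # prints runtime values under jit
    return 2.0 * y

out = scaled_sinc(jnp.array([0.0, 0.5]))
```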
Miscellaneous
| API | Description |
|---|---|
| Device | A descriptor of an available device. |
| print_environment_info([return_string]) | Returns a string containing local environment & JAX installation information. |
| live_arrays([platform]) | Return all live arrays in the backend for platform. |
| clear_caches() | Clear all compilation and staging caches. |
| typeof(x, /) | Return the JAX type (i.e. AbstractValue) of the input. |
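Several of these helpers are mainly useful for debugging and housekeeping. The following is a minimal sketch that:

- inspects a `Device`,
- counts the arrays still alive on the default backend,
- captures the environment report as a string via `return_string=True`, and
- clears the compilation caches.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3))

dev = jax.devices()[0]   # a jax.Device descriptor
print(dev.platform, dev.id)

# Arrays still held on the default backend; x is among them.
n_live = len(jax.live_arrays())
print("live arrays:", n_live)

# Return the report as a string instead of printing it.
info = jax.print_environment_info(return_string=True)
print(info.splitlines()[0])

# Drop cached compiled executables, e.g. between experiments.
jax.clear_caches()
```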
Checkpoint policies
| API | Description |
|---|---|
| checkpoint_policies.everything_saveable(**__) | The default strategy, as if jax.checkpoint were not being used at all. |
| checkpoint_policies.nothing_saveable(**__) | Rematerialize everything, as if a custom policy were not being used at all. |
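| checkpoint_policies.dots_saveable(*_, **__) | Save the outputs of matrix multiplications and convolutions; rematerialize everything else. |
| checkpoint_policies.checkpoint_dots(*_, **__) | Alias of dots_saveable. |
| checkpoint_policies.dots_with_no_batch_dims_saveable(...) | Save the outputs of matrix multiplications that have no batch dimensions; a useful heuristic for transformers. |
| checkpoint_policies.checkpoint_dots_with_no_batch_dims(...) | Alias of dots_with_no_batch_dims_saveable. |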
| checkpoint_policies.save_any_names_but_these() | Save only named values, i.e. any outputs of checkpoint_name, excluding the names given. |
| checkpoint_policies.save_only_these_names() | Save only named values, and only among the names given. |
| checkpoint_policies.offload_dot_with_no_batch_dims(...) | Same as dots_with_no_batch_dims_saveable, but offload to CPU memory instead of recomputing. |
| checkpoint_policies.save_and_offload_only_these_names(*, ...) | Same as save_only_these_names, but offload to CPU memory instead of recomputing. |
| checkpoint_policies.save_from_both_policies(...) | Logical OR of the given policies. |
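A policy is passed to `jax.checkpoint` (listed as `checkpoint` under Automatic differentiation) through its `policy` argument. The policy decides which intermediates are saved for the backward pass; everything else is recomputed. The following is a minimal sketch. It assumes `checkpoint_name` from `jax.ad_checkpoint` to tag an activation for the name-based policy, and `layer` is a hypothetical toy function. All policies should produce the same gradient, differing only in memory use and recomputation.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(W, x):
    # Tag the activation so name-based policies can refer to it.
    h = checkpoint_name(jnp.tanh(W @ x), "hidden")
    return jnp.sum(jnp.sin(h))

W = 0.5 * jnp.eye(3)
x = jnp.array([1.0, 2.0, 3.0])

policies = {
    "nothing_saveable": jax.checkpoint_policies.nothing_saveable,
    "dots_saveable": jax.checkpoint_policies.dots_saveable,
    "save_only_these_names": jax.checkpoint_policies.save_only_these_names("hidden"),
}

reference = jax.grad(layer)(W, x)
grads = {
    name: jax.grad(jax.checkpoint(layer, policy=policy))(W, x)
    for name, policy in policies.items()
}
for name, g in grads.items():
    # Policies change what is stored vs. recomputed, not the gradient itself.
    print(name, bool(jnp.allclose(g, reference)))
```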