jax.typing module

The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; for the proposal behind the types exported here, see https://docs.jax.dev/en/latest/jep/12049-type-annotations.html.

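The currently-available types are:

- jax.Array: annotation for any JAX array or tracer (i.e. representations of arrays within JAX transforms).
- jax.typing.ArrayLike: annotation for any value that is safely implicitly castable to a JAX array; this includes jax.Array, numpy.ndarray, Python builtin numeric values (e.g. int, float), and NumPy scalar values (e.g. numpy.int32, numpy.float64).
- jax.typing.DTypeLike: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. 'float32', 'int32'), scalar types (e.g. float, np.float32), dtypes (e.g. np.dtype('float32')), and objects with a dtype attribute (e.g. jnp.float32, jnp.int32).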

We may add additional types here in future releases.
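As a rough illustration of how these annotations might be combined, the sketch below annotates a small helper with ArrayLike for the input, DTypeLike (also exported by jax.typing) for the target dtype, and Array for the output. The name `cast_to` is hypothetical and not part of JAX; the annotations are static hints only and do not validate anything at runtime.

```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike, DTypeLike


def cast_to(x: ArrayLike, dtype: DTypeLike) -> Array:
    # Hypothetical helper: convert any array-like value to a jax.Array
    # with the requested dtype.
    return jnp.asarray(x, dtype=dtype)


# The dtype argument may be a string, a NumPy dtype, or a JAX scalar type.
a = cast_to(np.arange(3), "float32")
b = cast_to(2, jnp.int32)
c = cast_to(jnp.ones(2), np.dtype("float32"))
```

A static type checker should accept all three calls, since each argument falls within the corresponding annotation.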

JAX Typing Best Practices

When annotating JAX arrays in public API functions, we recommend using ArrayLike for array inputs and Array for array outputs.

For example, your function might look like this:

import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result
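To see how a function written this way behaves when called, here is a minimal, self-contained sketch. It repeats the function with only the version-independent check and then calls it with a NumPy array, a JAX array, and a Python list; the sample inputs are illustrative assumptions.

```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def my_function(x: ArrayLike) -> Array:
    # Version-independent runtime validation, as in the example above.
    if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
        raise TypeError(f"Expected arraylike input; got {x}")
    x_arr = jnp.asarray(x)
    # Sum over the leading axis, divided by its length.
    return x_arr.sum(0) / x_arr.shape[0]


# NumPy and JAX arrays are both accepted and both yield a jax.Array.
out_np = my_function(np.array([[1.0, 2.0], [3.0, 4.0]]))
out_jax = my_function(jnp.arange(4.0))

# A Python list is not array-like in this sense, so it is rejected.
try:
    my_function([1.0, 2.0, 3.0])
    rejected = False
except TypeError:
    rejected = True
```

Callers that hold a list can convert it explicitly with jnp.asarray before the call, which keeps the conversion visible at the call site.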

Most of JAX’s public APIs follow this pattern. Note in particular that we recommend that JAX functions not accept sequences such as list or tuple in place of arrays, as this can cause extra overhead in JAX transforms like jax.jit() and can behave in unexpected ways with batch-wise transforms like jax.vmap() or jax.pmap(). For more information on this, see Non-array inputs: NumPy vs JAX.
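The following sketch illustrates both points under simple assumptions. A list is treated as a pytree whose elements are separate leaves, so a transform handles one value per element rather than a single array, and jax.vmap maps over each leaf of a list rather than over the list itself. The helper `first_plus_second` and the sample data are hypothetical.

```python
import jax
import jax.numpy as jnp

# A list is a pytree: every element is a separate leaf, so a transform such
# as jax.jit sees one traced value per element instead of one array.
num_leaves_list = len(jax.tree_util.tree_leaves([1.0, 2.0, 3.0]))
num_leaves_array = len(jax.tree_util.tree_leaves(jnp.array([1.0, 2.0, 3.0])))


def first_plus_second(xs):
    # Hypothetical helper: add the first two entries of its argument.
    return xs[0] + xs[1]


rows = [jnp.array([1, 2, 3]), jnp.array([10, 20, 30])]

# Given a list, vmap maps over axis 0 of each leaf, so every call receives
# one scalar from each array.
from_list = jax.vmap(first_plus_second)(rows)

# Given a stacked array, vmap maps over axis 0 of that array, so every call
# receives one full row.
from_array = jax.vmap(first_plus_second)(jnp.stack(rows))
```

Here the list has three leaves while the array has one, and the two vmap calls return results of different shapes: three element-wise sums for the list, two per-row sums for the stacked array.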

List of Members