jax.numpy.unstack — JAX documentation
jax.numpy.unstack
jax.numpy.unstack(x, /, *, axis=0)
Unstack an array along an axis.
JAX implementation of array_api.unstack().
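As a rough mental model, and not a description of how JAX actually implements the function, unstacking along the default axis=0 gives the same arrays as indexing the leading axis one position at a time. The helper `unstack_axis0_by_indexing` is hypothetical and exists only for this comparison; the sketch assumes a JAX version that provides `jnp.unstack`.

```python
import jax.numpy as jnp


def unstack_axis0_by_indexing(x):
    # Hypothetical helper: index the leading axis one position at a time.
    x = jnp.asarray(x)
    return tuple(x[i] for i in range(x.shape[0]))


x = jnp.arange(6).reshape(2, 3)
expected = unstack_axis0_by_indexing(x)
actual = jnp.unstack(x)  # default axis=0

# Same number of pieces, and each piece matches element-for-element.
same = len(actual) == len(expected) and all(
    bool(jnp.array_equal(a, e)) for a, e in zip(actual, expected)
)
print(same)  # True
```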
Parameters:
- x (ArrayLike) – array to unstack. Must have x.ndim >= 1.
- axis (int) – integer axis along which to unstack. Must satisfy -x.ndim <= axis < x.ndim.
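The sketch below illustrates the axis parameter on a 3-D array, assuming standard NumPy-style axis semantics. It checks that unstacking along axis=1 yields x.shape[1] pieces with that axis removed, that a negative axis counts from the end, and that each piece matches the corresponding slice.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# Unstacking along axis=1 yields x.shape[1] == 3 pieces, each with axis 1 removed.
pieces = jnp.unstack(x, axis=1)
print(len(pieces), pieces[0].shape)  # 3 (2, 4)

# A negative axis counts from the end: for a 3-D array, axis=-2 is the same as axis=1.
pieces_neg = jnp.unstack(x, axis=-2)
print(all(bool(jnp.array_equal(a, b)) for a, b in zip(pieces, pieces_neg)))  # True

# Each piece equals the corresponding slice x[:, i, :].
print(all(bool(jnp.array_equal(pieces[i], x[:, i, :])) for i in range(3)))  # True
```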
Returns:
tuple of x.shape[axis] unstacked arrays, each with the shape of x with axis removed.
Return type:
tuple[Array, ...]
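Because array shapes are static under tracing, the number of returned pieces (x.shape[axis]) is known when a function is traced, so the tuple result can be used inside `jax.jit`. The function name `row_sums` below is hypothetical; this is a small sketch of that usage, not a recommended pattern for every workload.

```python
import jax
import jax.numpy as jnp


@jax.jit
def row_sums(x):
    # x.shape[0] is static under jit, so the tuple length is known at trace time.
    rows = jnp.unstack(x)
    return tuple(jnp.sum(r) for r in rows)


out = row_sums(jnp.array([[1, 2, 3], [4, 5, 6]]))
print(type(out).__name__, [int(v) for v in out])  # tuple [6, 15]
```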
See also
- jax.numpy.stack(): inverse of unstack.
- jax.numpy.split(): split an array into batches along an axis.
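To contrast unstack with split, the sketch below splits a 2-D array into size-1 chunks along axis 1 using `jnp.split`, then removes that axis with `jnp.squeeze`. The key difference it shows is that split keeps the split axis while unstack drops it; squeezing each chunk recovers the unstacked pieces.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

# split keeps the split axis (each chunk has shape (2, 1) here) ...
chunks = jnp.split(x, x.shape[1], axis=1)
print([c.shape for c in chunks])  # [(2, 1), (2, 1), (2, 1)]

# ... while unstack removes it (each piece has shape (2,)).
pieces = jnp.unstack(x, axis=1)
print([p.shape for p in pieces])  # [(2,), (2,), (2,)]

# Squeezing the split axis out of each chunk recovers the unstacked pieces.
squeezed = [jnp.squeeze(c, axis=1) for c in chunks]
print(all(bool(jnp.array_equal(s, p)) for s, p in zip(squeezed, pieces)))  # True
```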
Examples
>>> import jax.numpy as jnp
>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6]])
>>> arrs = jnp.unstack(arr)
>>> print(*arrs)
[1 2 3] [4 5 6]
stack() provides the inverse of this:
>>> jnp.stack(arrs)
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
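The inverse relationship also holds for a non-default axis, provided the same axis is passed to both functions. The sketch below checks the round trip for every valid axis of a 3-D array, including negative values; it assumes `jnp.stack` interprets a negative axis relative to the stacked result's rank, as NumPy does.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

for axis in range(-x.ndim, x.ndim):
    # Unstack along `axis`, then restack the pieces along the same axis.
    restored = jnp.stack(jnp.unstack(x, axis=axis), axis=axis)
    assert restored.shape == x.shape
    assert bool(jnp.array_equal(restored, x))

print("round trip OK for every valid axis")
```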