jax.numpy.unstack — JAX documentation

jax.numpy.unstack

jax.numpy.unstack(x, /, *, axis=0)

Unstack an array along an axis.

JAX implementation of array_api.unstack().
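To make the semantics concrete, here is a minimal sketch of an equivalent computation, assuming a JAX release that provides jnp.unstack. The helper unstack_reference is hypothetical and exists only for illustration; it is not part of JAX. It moves the requested axis to the front with jnp.moveaxis and then takes one slice per index along that leading axis, and the loop checks that the result matches jnp.unstack.

```python
import jax.numpy as jnp


def unstack_reference(x, axis=0):
    """Hypothetical helper (not part of JAX) mirroring jnp.unstack's result."""
    x = jnp.asarray(x)
    if x.ndim == 0:
        raise ValueError("cannot unstack a 0-d array")
    # Move the requested axis to the front; moveaxis also accepts negative axes.
    moved = jnp.moveaxis(x, axis, 0)
    # Take one slice per index along the (now leading) axis.
    return tuple(moved[i] for i in range(moved.shape[0]))


x = jnp.arange(24).reshape(2, 3, 4)
for axis in (0, 1, 2, -1):
    expected = jnp.unstack(x, axis=axis)
    got = unstack_reference(x, axis=axis)
    # Same number of pieces, and each piece matches element-wise.
    assert len(got) == len(expected) == x.shape[axis]
    assert all(jnp.array_equal(a, b) for a, b in zip(got, expected))
```

This is only meant to illustrate what the result looks like; the library implementation may compute the slices differently.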

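Parameters:

x (ArrayLike) – input array. Must have at least one dimension.

axis (int) – integer axis along which to unstack. Negative values count from the last axis, so the axis must satisfy -x.ndim <= axis < x.ndim. Defaults to 0.
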
Returns:

tuple of arrays, one for each index along axis. Each array has the shape of x with axis removed.

Return type:

tuple[Array, …]
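The following sketch, assuming a JAX release that includes jnp.unstack, shows how axis determines how many arrays are returned and what shape each one has. The input array is an arbitrary example chosen for illustration.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

rows = jnp.unstack(x)           # axis=0 (default): 2 arrays of shape (3,)
cols = jnp.unstack(x, axis=1)   # axis=1: 3 arrays of shape (2,)
last = jnp.unstack(x, axis=-1)  # negative axis counts from the end; same as axis=1 here

print(len(rows), rows[0].shape)  # 2 (3,)
print(len(cols), cols[0].shape)  # 3 (2,)
print(cols[1])                   # [1 4]
```

The length of the returned tuple always equals x.shape[axis], and each element drops exactly that one dimension.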

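See also

jax.numpy.stack(): inverse of unstack.

jax.numpy.split(): split an array into multiple sub-arrays along an axis.
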
Examples

>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6]])
>>> arrs = jnp.unstack(arr)
>>> print(*arrs)
[1 2 3] [4 5 6]

jax.numpy.stack() provides the inverse of this operation:

>>> jnp.stack(arrs)
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
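The inverse relationship holds for any axis, provided the pieces are re-stacked along the same axis they were unstacked from. A minimal sketch, assuming a JAX release that provides jnp.unstack; the 3-D input is an arbitrary example.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

for axis in range(-x.ndim, x.ndim):
    parts = jnp.unstack(x, axis=axis)
    # Re-stacking along the same axis restores the original array.
    restored = jnp.stack(parts, axis=axis)
    assert restored.shape == x.shape
    assert jnp.array_equal(restored, x)

# Stacking along a different axis moves the unstacked axis instead of restoring it.
moved = jnp.stack(jnp.unstack(x, axis=1), axis=0)
assert jnp.array_equal(moved, jnp.moveaxis(x, 1, 0))
```

In the last case the result has shape (3, 2, 4): the axis that was unstacked becomes the new leading axis created by jnp.stack.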