jax.numpy.stack - JAX documentation

jax.numpy.stack

jax.numpy.stack(arrays, axis=0, out=None, dtype=None)

Join arrays along a new axis.

JAX implementation of numpy.stack().
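"Along a new axis" means that stacking N inputs of shape S produces a result whose shape is S with N inserted at position `axis`. The minimal sketch below illustrates this; the four input arrays are arbitrary values chosen for the example. It checks the result shapes for several values of `axis` and shows that inputs with mismatched shapes are rejected. The `ValueError` reflects current JAX releases and is an assumption for other versions.

```python
import jax.numpy as jnp

# Four illustrative inputs, all with the same shape (2, 3).
arrays = [jnp.full((2, 3), i) for i in range(4)]

# The new axis has length len(arrays) and sits at position `axis` of the
# result, so the result has one more dimension than each input.
shapes = {axis: jnp.stack(arrays, axis=axis).shape for axis in (0, 1, 2, -1)}
print(shapes)  # {0: (4, 2, 3), 1: (2, 4, 3), 2: (2, 3, 4), -1: (2, 3, 4)}

# Along axis=0, slice i of the result is exactly the i-th input.
stacked = jnp.stack(arrays, axis=0)
assert all(bool(jnp.array_equal(stacked[i], arrays[i])) for i in range(4))

# Inputs whose shapes differ are rejected (assumed to raise ValueError).
mismatch_rejected = False
try:
    jnp.stack([jnp.zeros(2), jnp.zeros(3)])
except ValueError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```

Note that a negative `axis` counts from the end of the result's dimensions, so `axis=-1` places the new axis last.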

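Parameters:

arrays: a sequence of arrays to stack. All arrays must have the same shape.

axis (int): the axis of the result along which the inputs are stacked. Defaults to 0.

out (None): unused by JAX.

dtype (optional): dtype of the resulting array. If not specified, the dtype is determined by JAX's type promotion rules.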

Returns:

the stacked result.

Return type:

Array
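The `dtype` parameter in the signature controls the dtype of the returned Array. The following is a minimal sketch under JAX's default configuration. The exact promoted dtype depends on configuration (float32 by default, float64 when 64-bit mode is enabled), so the code only checks that the promoted result is floating point. The `float16` override is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

ints = jnp.array([1, 2, 3])          # integer input
floats = jnp.array([0.5, 1.5, 2.5])  # floating-point input

# Without `dtype`, type promotion decides the result dtype:
# mixing an integer array with a float array gives a floating-point result.
promoted = jnp.stack([ints, floats])
print(promoted.dtype)

# With `dtype`, the inputs are cast so the result has exactly that dtype.
as_float16 = jnp.stack([ints, ints], dtype=jnp.float16)
print(as_float16.dtype)  # float16
```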

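See also

jax.numpy.unstack(): the inverse of stack.

jax.numpy.concatenate(): join arrays along an existing axis.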
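The difference between stack and jnp.concatenate is that concatenate joins arrays along an axis they already have, while stack creates a new one. The sketch below shows that stacking along `axis` produces the same result as giving each input a length-1 axis with jnp.expand_dims and then concatenating along it. This equivalence is offered as an illustration of the semantics, not as a description of how JAX implements stack internally. The arrays `a` and `b` are arbitrary example inputs.

```python
import jax.numpy as jnp

a = jnp.arange(6).reshape(2, 3)
b = jnp.arange(6, 12).reshape(2, 3)

for axis in (0, 1, 2):
    via_stack = jnp.stack([a, b], axis=axis)
    # Insert a length-1 axis at `axis` in each input, then join along it.
    via_concat = jnp.concatenate(
        [jnp.expand_dims(a, axis), jnp.expand_dims(b, axis)], axis=axis
    )
    assert via_stack.shape == via_concat.shape
    assert bool(jnp.array_equal(via_stack, via_concat))

# concatenate keeps the number of dimensions; stack adds one.
print(jnp.concatenate([a, b], axis=0).shape)  # (4, 3)
print(jnp.stack([a, b], axis=0).shape)        # (2, 2, 3)
```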

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

unstack() performs the inverse operation:

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)
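The inverse relationship holds for every axis, not only `axis=1`. This minimal sketch assumes a JAX version that provides jnp.unstack, as the example above does. It splits a 3-D array along each axis in turn and checks that stacking the pieces back along the same axis restores the original. The array `arr` here is an arbitrary example input.

```python
import jax.numpy as jnp

arr = jnp.arange(24).reshape(2, 3, 4)

for axis in range(arr.ndim):
    # unstack returns a tuple with one array per index along `axis`,
    # each with that axis removed.
    pieces = jnp.unstack(arr, axis=axis)
    assert len(pieces) == arr.shape[axis]
    # Stacking the pieces along the same axis restores the original array.
    restored = jnp.stack(pieces, axis=axis)
    assert bool(jnp.array_equal(restored, arr))
```

Because unstack returns a tuple, its result can be unpacked directly into separate names, as in `x, y = jnp.unstack(arr, axis=1)` above.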