jax.numpy.stack — JAX documentation
jax.numpy.stack
jax.numpy.stack(arrays, axis=0, out=None, dtype=None)
Join arrays along a new axis.
JAX implementation of numpy.stack().
Parameters:
- arrays (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape. If a single array is given, it is treated as equivalent to arrays = unstack(arrays), but the implementation avoids explicit unstacking.
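- axis (int) – position of the new axis in the result, along which the inputs are stacked. Defaults to 0. Negative values count from the end of the result's dimensions, so axis=-1 places the new axis last.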
- out (None) – unused by JAX
- dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
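The following is a minimal sketch of the parameter behavior described above, assuming a JAX version recent enough to provide jax.numpy.unstack(). It checks that a single array is stacked as if it had first been unstacked along axis 0, that dtype overrides the promoted dtype, and that inputs with different shapes are rejected (the error type is assumed to be ValueError, matching NumPy).

```python
import jax.numpy as jnp

# A single (3, 2) array is treated like the sequence of its three rows.
a = jnp.arange(6).reshape(3, 2)
stacked = jnp.stack(a, axis=1)

# Explicitly unstacking along axis 0 first should give the same result.
explicit = jnp.stack(jnp.unstack(a), axis=1)
print(stacked.shape)                             # (2, 3)
print(bool(jnp.array_equal(stacked, explicit)))  # True

# dtype overrides the integer dtype that type promotion would pick here.
f = jnp.stack([jnp.array([1, 2]), jnp.array([3, 4])], dtype=jnp.float32)
print(f.dtype)                                   # float32

# Inputs whose shapes differ are rejected (assumed to raise ValueError).
try:
    jnp.stack([jnp.zeros(3), jnp.zeros(4)])
    shape_mismatch_raised = False
except ValueError:
    shape_mismatch_raised = True
print(shape_mismatch_raised)                     # True
```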
Returns:
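the stacked result: an array with the common input shape plus a new dimension of size len(arrays) inserted at position axis.
Return type:
Array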
See also
- jax.numpy.unstack(): inverse of stack.
- jax.numpy.concatenate(): concatenation along existing axes.
- jax.numpy.vstack(): stack vertically, i.e. along axis 0.
- jax.numpy.hstack(): stack horizontally, i.e. along axis 1 (along axis 0 for 1-D inputs).
- jax.numpy.dstack(): stack depth-wise, i.e. along axis 2.
- jax.numpy.column_stack(): stack columns.
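To relate stack to the functions listed above, here is a small sketch using 1-D inputs like those in the examples. The helper stack_via_concatenate is hypothetical and not part of jax.numpy; it expresses stacking as inserting a length-1 axis into each input with jax.numpy.expand_dims() and then concatenating along that axis. The shape comparison also shows that hstack() on 1-D inputs joins along axis 0, unlike stack(..., axis=1).

```python
import jax.numpy as jnp

def stack_via_concatenate(arrays, axis=0):
    # Hypothetical helper (not a jax.numpy function): give every input a
    # length-1 axis at `axis`, then join the inputs along that axis.
    return jnp.concatenate([jnp.expand_dims(a, axis) for a in arrays], axis=axis)

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

for axis in (0, 1, -1):
    same = jnp.array_equal(jnp.stack([x, y], axis=axis),
                           stack_via_concatenate([x, y], axis=axis))
    print(axis, bool(same))  # True for each axis

# Shapes produced by the related helpers on the same 1-D inputs.
shapes = {
    "stack axis=0": jnp.stack([x, y]).shape,          # (2, 3)
    "stack axis=1": jnp.stack([x, y], axis=1).shape,  # (3, 2)
    "vstack": jnp.vstack([x, y]).shape,               # (2, 3)
    "hstack": jnp.hstack([x, y]).shape,               # (6,)
    "dstack": jnp.dstack([x, y]).shape,               # (1, 3, 2)
    "column_stack": jnp.column_stack([x, y]).shape,   # (3, 2)
}
for name, shape in shapes.items():
    print(name, shape)
```

The expand-then-concatenate form is only an illustration of the semantics; it is not a claim about how jax.numpy.stack is implemented internally.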
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
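For higher-rank inputs the same rule applies: with n inputs of a common shape, the result has a new dimension of size n at position axis. The sketch below checks this for every valid axis of a rank-3 result. The helper expected_shape is hypothetical, written only for this illustration; it normalizes a negative axis against the rank of the output, which is one more than the input rank.

```python
import jax.numpy as jnp

def expected_shape(shape, n, axis):
    # Hypothetical helper: a negative axis counts from the end of the output,
    # whose rank is len(shape) + 1. Only meaningful for valid axes.
    axis = axis % (len(shape) + 1)
    return shape[:axis] + (n,) + shape[axis:]

# Three inputs with the same shape (2, 4).
arrays = [jnp.zeros((2, 4)), jnp.ones((2, 4)), jnp.full((2, 4), 2.0)]

# Valid axes for a rank-3 result run from -3 to 2.
shapes = {}
for axis in range(-3, 3):
    shapes[axis] = jnp.stack(arrays, axis=axis).shape
    print(axis, shapes[axis], expected_shape((2, 4), len(arrays), axis))
```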
unstack() performs the inverse operation:
>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)
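As a further illustration of the inverse relationship, this sketch (assuming jax.numpy.unstack() is available in the installed JAX version) checks that unstacking along the same axis used for stacking recovers the original 2-D inputs for each valid axis, including a negative one.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
y = x + 10

recovered = {}
for axis in (0, 1, 2, -1):
    arr = jnp.stack([x, y], axis=axis)    # new axis of size 2 at `axis`
    x2, y2 = jnp.unstack(arr, axis=axis)  # split along that same axis
    recovered[axis] = bool(jnp.array_equal(x2, x)) and bool(jnp.array_equal(y2, y))
    print(axis, arr.shape, recovered[axis])
```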