jax.numpy.hstack (JAX documentation)
jax.numpy.hstack
jax.numpy.hstack(tup, dtype=None)
Horizontally stack arrays.
JAX implementation of numpy.hstack().
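For arrays of two or more dimensions, this is equivalent to jax.numpy.concatenate() with axis=1. 1D inputs have no second axis, so they are concatenated along axis 0 instead, matching numpy.hstack().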
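The axis rule above can be checked directly. The following is a minimal sketch, assuming a standard JAX installation in its default 32-bit mode; the array names are illustrative only. It compares jnp.hstack with jnp.concatenate for 2D inputs (axis 1) and for 1D inputs (axis 0).

```python
import jax.numpy as jnp

# 2D inputs: hstack joins along axis 1 (columns).
a = jnp.arange(6).reshape(2, 3)
b = jnp.zeros((2, 2), dtype=a.dtype)
h2 = jnp.hstack([a, b])
c2 = jnp.concatenate([a, b], axis=1)
print(h2.shape)                        # (2, 5)
print(bool(jnp.array_equal(h2, c2)))   # True

# 1D inputs: there is no axis 1, so hstack joins along axis 0 instead.
u = jnp.arange(3)
v = jnp.arange(3, 5)
h1 = jnp.hstack([u, v])
c1 = jnp.concatenate([u, v], axis=0)
print(h1.shape)                        # (5,)
print(bool(jnp.array_equal(h1, c1)))   # True
```

The 1D case is the reason hstack is not simply an alias for concatenate(..., axis=1).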
Parameters:
- tup (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape along all but the second axis (1D inputs may differ in length). Input arrays are promoted to at least rank 1. If a single array is given, it is treated as equivalent to tup = unstack(tup), but the implementation avoids explicit unstacking.
- dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype is determined by the type promotion rules described in Type promotion semantics.
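To illustrate the tup semantics, the sketch below passes a single 3D array and compares the result with hstack over its leading-axis slices, then mixes a 0-d scalar with a 1D array to show rank promotion. It assumes the default 32-bit mode; indexing with arr[0] and arr[1] is used here as a stand-in for unstack along axis 0, and all variable names are hypothetical.

```python
import jax.numpy as jnp

# A single array of shape (2, 3, 4) is treated as a sequence of
# two (3, 4) slices taken along its leading axis.
arr = jnp.arange(24).reshape(2, 3, 4)
from_array = jnp.hstack(arr)
# Explicit split along axis 0 (what unstack does by default).
from_slices = jnp.hstack([arr[0], arr[1]])
print(from_array.shape)                                # (3, 8)
print(bool(jnp.array_equal(from_array, from_slices)))  # True

# Inputs are promoted to at least rank 1, so a 0-d scalar and a 1D array mix.
mixed = jnp.hstack([jnp.array(7), jnp.arange(2)])
print(mixed)  # [7 0 1]
```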
Returns:
the stacked result.
Return type:
Array
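The dtype of the returned array depends on whether dtype is passed. A minimal sketch, assuming the default 32-bit mode (so jnp.arange yields int32); the choice of float16 is only an example target dtype.

```python
import jax.numpy as jnp

x = jnp.arange(3)                    # int32 in the default 32-bit mode
y = jnp.ones(3, dtype=jnp.float32)

# Without dtype, the result dtype follows JAX's type promotion rules.
promoted = jnp.hstack([x, y])
print(promoted.dtype)  # float32

# With dtype, the stacked result is produced in the requested dtype.
as_f16 = jnp.hstack([x, y], dtype=jnp.float16)
print(as_f16.dtype)  # float16
```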
See also
- jax.numpy.stack(): stack along arbitrary axes.
- jax.numpy.concatenate(): concatenation along existing axes.
- jax.numpy.vstack(): stack vertically, i.e. along axis 0.
- jax.numpy.dstack(): stack depth-wise, i.e. along axis 2.
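To contrast the related functions listed above, the sketch below applies each to the same pair of 2D arrays and prints the resulting shapes. It is a minimal illustration assuming a standard JAX installation; the variable names are hypothetical.

```python
import jax.numpy as jnp

# Two 2D inputs of identical shape (2, 3).
a = jnp.zeros((2, 3))
b = jnp.ones((2, 3))

h = jnp.hstack([a, b])               # join along axis 1
v = jnp.vstack([a, b])               # join along axis 0
d = jnp.dstack([a, b])               # inputs become (2, 3, 1), joined along axis 2
s = jnp.stack([a, b], axis=0)        # insert a new leading axis
c = jnp.concatenate([a, b], axis=1)  # same as hstack for 2D inputs

for name, result in [("hstack", h), ("vstack", v), ("dstack", d),
                     ("stack", s), ("concatenate", c)]:
    print(name, result.shape)
# hstack (2, 6)
# vstack (4, 3)
# dstack (2, 3, 2)
# stack (2, 2, 3)
# concatenate (2, 6)
```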
Examples
Scalar values:
>>> import jax.numpy as jnp
>>> jnp.hstack([1, 2, 3])
Array([1, 2, 3], dtype=int32, weak_type=True)
1D arrays:
>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.hstack([x, y])
Array([0., 1., 2., 1., 1., 1.], dtype=float32)
2D arrays:
>>> x = x.reshape(3, 1)
>>> y = y.reshape(3, 1)
>>> jnp.hstack([x, y])
Array([[0., 1.],
       [1., 1.],
       [2., 1.]], dtype=float32)