jax.numpy.reshape - JAX documentation

jax.numpy.reshape

jax.numpy.reshape(a, shape, order='C', *, copy=None, out_sharding=None)

Return a reshaped copy of an array.

JAX implementation of numpy.reshape(), built on jax.lax.reshape().
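To make the relationship with the lower-level function concrete, the sketch below (assuming a standard JAX installation) reshapes the same input with jnp.reshape and with a direct call to jax.lax.reshape and checks that the results agree. It passes a fully specified shape to jax.lax.reshape, because the example is not meant to rely on the lower-level function understanding -1.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2, 3], [4, 5, 6]])

# NumPy-style entry point.
via_jnp = jnp.reshape(x, (3, 2))

# Lower-level primitive; every output dimension is given explicitly (no -1).
via_lax = lax.reshape(x, (3, 2))

print(via_lax.shape)                            # (3, 2)
print(bool(jnp.array_equal(via_jnp, via_lax)))  # True
```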

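Parameters:

a: input array to reshape.
shape: integer or sequence of integers giving the new shape. At most one dimension may be -1; its size is then inferred so that the total number of elements is unchanged.
order: 'C' (default) for C-style row-major ordering, or 'F' for Fortran-style column-major ordering.
copy: unused by JAX; a copy is always returned (see Notes).
out_sharding: optional sharding for the output array.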

Returns:

reshaped copy of the input array with the specified shape.

Return type:

Array

Notes

Unlike numpy.reshape(), jax.numpy.reshape() returns a copy of the input array rather than a view. Under JIT, however, the compiler optimizes such copies away when possible, so in practice this usually carries no performance cost.
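Because JAX arrays are immutable, the copy-versus-view difference cannot be observed through array values. The sketch below uses a hypothetical helper, to_column (not part of the JAX API), that reshapes inside jax.jit; it then shows that an .at[...].set(...) update on the result produces a new array and leaves the input untouched.

```python
import jax
import jax.numpy as jnp

@jax.jit
def to_column(a):
    # Hypothetical helper. Inside jit, the compiler may avoid
    # materializing a separate copy for this reshape.
    return jnp.reshape(a, (-1, 1))

x = jnp.array([[1, 2, 3], [4, 5, 6]])
y = to_column(x)

# JAX arrays are immutable: "updating" y returns a new array,
# so neither y nor x is modified.
z = y.at[0, 0].set(100)

print(y.shape)       # (6, 1)
print(int(x[0, 0]))  # 1
print(int(y[0, 0]))  # 1
print(int(z[0, 0]))  # 100
```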

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
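Since jax.numpy.reshape mirrors numpy.reshape, one way to sanity-check these examples is to compare both functions on the same data. This sketch assumes NumPy is available (it is a JAX dependency); the list of target shapes is arbitrary.

```python
import numpy as np
import jax.numpy as jnp

data = [[1, 2, 3], [4, 5, 6]]
x_np = np.array(data)
x_jnp = jnp.array(data)

# Compare shapes and values for a few target shapes, including a plain integer.
for shape in [6, (3, 2), (1, 6), (6, 1)]:
    expected = np.reshape(x_np, shape)
    actual = jnp.reshape(x_jnp, shape)
    assert actual.shape == expected.shape
    # np.asarray converts the JAX array to a NumPy array for the comparison.
    assert np.array_equal(np.asarray(actual), expected)

print("jnp.reshape agrees with np.reshape on all tested shapes")
```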

You can pass -1 for one dimension to have its size computed automatically, so that the new shape is consistent with the input size:

>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
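The rule behind -1 is simple arithmetic: the missing dimension is the total number of elements divided by the product of the known dimensions, and that division must be exact. The helper infer_shape below is a hypothetical illustration of this rule, not a JAX function. The loop checks it against the shapes jnp.reshape actually produces, and the last call shows an incompatible request being rejected; since this sketch does not assume a specific exception type, it catches both TypeError and ValueError.

```python
import math
import jax.numpy as jnp

def infer_shape(size, shape):
    """Hypothetical helper (not part of JAX): resolve a single -1 in `shape`."""
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    # Product of the explicitly given dimensions (1 if -1 is the only entry).
    known = math.prod(d for d in shape if d != -1)
    if -1 in shape:
        if known == 0 or size % known != 0:
            raise ValueError(f"cannot infer -1 for size {size} and shape {shape}")
        shape = tuple(size // known if d == -1 else d for d in shape)
    if math.prod(shape) != size:
        raise ValueError(f"shape {shape} does not match size {size}")
    return shape

x = jnp.array([[1, 2, 3], [4, 5, 6]])

# The helper's answer should match the shape jnp.reshape actually produces.
for requested in [-1, (-1, 2), (3, -1), (1, -1, 3)]:
    inferred = infer_shape(x.size, requested)
    assert jnp.reshape(x, requested).shape == inferred
    print(requested, "->", inferred)

# 6 elements cannot be split into rows of 4, so this reshape is rejected.
rejected = False
try:
    jnp.reshape(x, (-1, 4))
except (TypeError, ValueError):
    rejected = True
print("rejected:", rejected)
```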

By default, reshape reads and places elements in C-style row-major order. To use Fortran-style column-major order instead, specify order='F':

>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)
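One way to see what order='F' does is to express it with C-order operations: reading an array in column-major order is the same as reading its transpose in row-major order. The hypothetical helper reshape_fortran below (not a JAX function) uses that identity (transpose, reshape to the reversed shape in C order, transpose back) and checks its result against order='F'.

```python
import jax.numpy as jnp

def reshape_fortran(a, shape):
    """Hypothetical helper (not part of JAX): Fortran-order reshape via C-order steps.

    Reading `a` in column-major order equals reading `a.T` in row-major order,
    so we transpose, reshape to the reversed shape in C order, then transpose back.
    """
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    return jnp.reshape(a.T, shape[::-1]).T

x = jnp.array([[1, 2, 3], [4, 5, 6]])
for shape in [6, (3, 2), (1, 6)]:
    assert jnp.array_equal(reshape_fortran(x, shape), jnp.reshape(x, shape, order='F'))

print(reshape_fortran(x, (3, 2)))
```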

For convenience, this functionality is also available via the jax.Array.reshape() method:

>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
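As with NumPy arrays, the method form accepts the new shape either as separate integer arguments or as a single tuple, and -1 works the same way. The short check below compares both call styles with the function form.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3], [4, 5, 6]])

# The method accepts the target shape as separate integers or as one tuple.
a = x.reshape(3, 2)
b = x.reshape((3, 2))
c = jnp.reshape(x, (3, 2))
assert jnp.array_equal(a, b) and jnp.array_equal(b, c)

# -1 is resolved the same way as in the function form.
print(x.reshape(-1, 3).shape)  # (2, 3)
```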