jax.Array.reshape (JAX documentation)
jax.Array.reshape
abstract Array.reshape(*args, order='C', out_sharding=None)
Returns an array containing the same data with a new shape.
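A minimal sketch of that behaviour, assuming a standard JAX install (any backend). The new shape can be passed either as separate integers or as a single tuple, and flattening the result recovers the original elements in order. Because JAX arrays are immutable, `x` itself keeps its original shape. The array values are illustrative only.

```python
import jax.numpy as jnp

# A 1-D array with six elements: [0, 1, 2, 3, 4, 5]
x = jnp.arange(6)

# The new shape may be given as separate integers or as one tuple.
a = x.reshape(2, 3)
b = x.reshape((2, 3))

print(a.shape)  # (2, 3)
print(b.shape)  # (2, 3)

# Same data, new shape: flattening recovers the original elements in order.
print(jnp.array_equal(a.ravel(), x))  # True

# JAX arrays are immutable, so x keeps its original shape.
print(x.shape)  # (6,)
```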
Refer to jax.numpy.reshape() for full documentation.
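Since the method defers to `jax.numpy.reshape()`, the method form and the function form should produce the same result. The sketch below (illustrative values only) also passes `-1` for one dimension so that its size is inferred from the total number of elements, the same convention NumPy uses.

```python
import jax.numpy as jnp

x = jnp.arange(12)

# Method form and function form of the same reshape.
m = x.reshape(3, -1)         # -1 is inferred as 12 // 3 == 4
f = jnp.reshape(x, (3, -1))  # shape passed positionally as the second argument

print(m.shape)                # (3, 4)
print(jnp.array_equal(m, f))  # True
```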
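Parameters:
    *args: the new shape, passed either as a single tuple or as separate integers; one dimension may be -1 to have its size inferred.
    order (str): the index order used to read and place elements, either 'C' (row-major, the default) or 'F' (column-major).
    out_sharding: optional sharding for the output array; defaults to None.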
Return type:
    Array
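The effect of the `order` parameter is easiest to see on a small array. This sketch, with illustrative values only, reshapes the same six elements with row-major (`'C'`) and column-major (`'F'`) ordering.

```python
import jax.numpy as jnp

x = jnp.arange(6)  # [0, 1, 2, 3, 4, 5]

# order='C' (default): fill row by row, so the last axis varies fastest.
c = x.reshape(2, 3, order='C')
# order='F': fill column by column, so the first axis varies fastest.
f = x.reshape(2, 3, order='F')

print(c.tolist())  # [[0, 1, 2], [3, 4, 5]]
print(f.tolist())  # [[0, 2, 4], [1, 3, 5]]
```

Both results have shape `(2, 3)` and contain the same elements; only the placement of those elements differs.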