jax.numpy.copy — JAX documentation
jax.numpy.copy
jax.numpy.copy(a, order=None)
Return a copy of the array.
JAX implementation of numpy.copy().
Parameters:
- a (ArrayLike) – arraylike object to copy
- order (str | None) – not implemented in JAX
Returns:
a copy of the input array a.
Return type:
Array
See also
- jax.numpy.array(): create an array with or without a copy.
- jax.Array.copy(): same function accessed as an array method.
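As a quick check of the equivalences listed above, the sketch below makes one copy through each spelling and verifies that every result matches the input in shape, dtype, and values. It assumes a recent JAX release in which jax.Array works with isinstance and jnp.array accepts copy=True. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)

a = jnp.copy(x)              # function form
b = x.copy()                 # method form (jax.Array.copy)
c = jnp.array(x, copy=True)  # jnp.array with an explicit copy request

for arr in (a, b, c):
    # Each result is a jax.Array with the same shape, dtype, and values as x.
    assert isinstance(arr, jax.Array)
    assert arr.shape == x.shape and arr.dtype == x.dtype
    assert bool(jnp.array_equal(arr, x))

print(a)  # [0. 1. 2. 3.]
```

Because JAX arrays are immutable, you can only observe that these are separate buffers in situations such as argument donation.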
Examples
Since JAX arrays are immutable, in most cases explicit array copies are not necessary. One exception is when using a function with donated arguments (see the donate_argnums argument to jax.jit()).
>>> import jax
>>> import jax.numpy as jnp
>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]
Because we marked x as being donated, the original array is no longer available:
>>> print(x)
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].
In situations like this, an explicit copy will let you keep access to the original buffer:
>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]
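Outside of donation, immutability is what makes explicit copies unnecessary. An "update" such as x.at[0].set(...) returns a new array and leaves every existing reference unchanged. The minimal sketch below contrasts this with NumPy, where a plain assignment aliases a mutable buffer. The variable names are illustrative, and the sketch assumes the usual JAX behaviour of raising TypeError on in-place item assignment.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: a plain assignment aliases the same buffer, so an in-place
# write through one name is visible through the other.
x_np = np.arange(4)
alias_np = x_np
x_np[0] = 100
assert alias_np[0] == 100

# JAX: arrays are immutable. .at[...].set(...) returns a new array,
# so a plain reference already behaves like a copy.
x = jnp.arange(4)
alias = x
updated = x.at[0].set(100)
print(alias)  # [0 1 2 3]

# In-place item assignment is rejected instead of mutating x.
try:
    x[0] = 100
    setitem_raised = False
except TypeError:
    setitem_raised = True
```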