jax.numpy.copy — JAX documentation

jax.numpy.copy

jax.numpy.copy(a, order=None)

Return a copy of the array.

JAX implementation of numpy.copy().

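Parameters:

a (ArrayLike) – array-like object to copy.

order (str | None) – accepted for compatibility with numpy.copy(); memory-layout orders are not implemented in JAX.
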
Returns:

a copy of the input array a.

Return type:

Array
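A minimal sketch of basic usage, assuming JAX's default 32-bit mode (the variable names are illustrative only). It shows that jnp.copy and the jax.Array.copy method give the same result, and that a NumPy input comes back as a jax.Array.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(4)                      # int32 under the default 32-bit mode
host = np.arange(4, dtype=np.float32)  # a NumPy array is also a valid input

y = jnp.copy(x)     # function form
z = x.copy()        # method form on jax.Array
w = jnp.copy(host)  # NumPy input is copied into a jax.Array

# The copies carry the same values, shape, and dtype as their inputs.
print(bool(jnp.array_equal(y, x)), bool(jnp.array_equal(z, x)))  # True True
print(isinstance(w, jax.Array), w.dtype)                          # True float32
```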

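See also

jax.numpy.array(): create an array with or without a copy.

jax.Array.copy(): the same function accessed as an array method.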

Examples

Since JAX arrays are immutable, in most cases explicit array copies are not necessary. One exception is when using a function with donated arguments (see the donate_argnums argument to jax.jit()).
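The immutability point can be seen directly in a short sketch: item assignment on a JAX array raises a TypeError, while an .at[...].set() update returns a new array and leaves the original intact, so no explicit copy is needed to protect it.

```python
import jax.numpy as jnp

x = jnp.arange(4)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 10
except TypeError:
    print("in-place assignment is not supported")

# A functional update returns a new array and leaves x unchanged,
# so x never needs a defensive copy here.
x2 = x.at[0].set(10)
print(x.tolist())   # [0, 1, 2, 3]
print(x2.tolist())  # [10, 1, 2, 3]
```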

>>> import jax
>>> import jax.numpy as jnp
>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]

Because we marked x as being donated, the original array is no longer available:

>>> print(x)
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].

In situations like this, an explicit copy will let you keep access to the original buffer:

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]
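If the donated argument is a pytree of arrays rather than a single array, one possible pattern is to map jnp.copy over every leaf with jax.tree_util.tree_map before the call. This is a sketch only: `update` and `params` are hypothetical names, not part of this API.

```python
import jax
import jax.numpy as jnp

# Hypothetical update step: doubles every leaf and donates its argument.
update = jax.jit(
    lambda params: jax.tree_util.tree_map(lambda p: 2 * p, params),
    donate_argnums=0,
)

params = {"w": jnp.ones(3), "b": jnp.zeros(())}

# jnp.copy acts on one array at a time; tree_map applies it to every leaf,
# so the donated buffers are fresh copies and `params` remains usable.
new_params = update(jax.tree_util.tree_map(jnp.copy, params))

print(params["w"].tolist())      # [1.0, 1.0, 1.0]
print(new_params["w"].tolist())  # [2.0, 2.0, 2.0]
```

Because the copy allocates new buffers, peak memory is roughly what it would be without donation, so this pattern is mainly useful when the donating function is fixed and the original values must survive the call.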