jax.numpy.put — JAX documentation

jax.numpy.put

jax.numpy.put(a, ind, v, mode=None, *, inplace=True)

Put elements into an array at given indices.

JAX implementation of numpy.put().

The semantics of numpy.put() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
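A minimal sketch contrasting the two behaviours, assuming NumPy and JAX are both installed: numpy.put writes into its argument, while jax.numpy.put leaves its input untouched and returns an updated copy. Leaving inplace at its default of True is expected to raise an error; the sketch catches a generic exception rather than assuming a specific exception type.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: put() writes into the existing array and returns None.
x_np = np.zeros(5, dtype=int)
np.put(x_np, [0, 2, 4], [10, 20, 30])
print(x_np)  # [10  0 20  0 30]

# JAX: arrays are immutable, so put() returns a new array instead.
x_jax = jnp.zeros(5, dtype=int)
y_jax = jnp.put(x_jax, jnp.array([0, 2, 4]), jnp.array([10, 20, 30]), inplace=False)
print(x_jax)  # input unchanged: [0 0 0 0 0]
print(y_jax)  # [10  0 20  0 30]

# Omitting inplace=False should be rejected rather than silently ignored.
try:
    jnp.put(x_jax, jnp.array([0]), jnp.array([1]))
    rejected = False
except Exception as err:  # exact exception type is an assumption, not checked here
    rejected = True
    print(type(err).__name__, err)
```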

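Parameters:

a: array into which values will be placed.
ind: array of indices into the flattened array at which to put values.
v: array of values to put into the array.
mode: string specifying how to handle out-of-bound indices. "clip" (the default) clips them to the final index; "wrap" wraps them around to the beginning of the array.
inplace: must be set to False, to indicate that the input is not modified in-place and that a modified copy is returned instead.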

Returns:

A copy of a with specified entries updated.

Return type:

Array

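See also

jax.numpy.ndarray.at: indexed array updates; x.at[indices].set(values) is equivalent to put for in-bound indices, as the examples show.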

Examples

>>> import jax.numpy as jnp
>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)

This is equivalent to the following jax.numpy.ndarray.at indexing syntax:

>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)
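Because jnp.put is a pure function that returns a new array, it can be traced by jax.jit like other jax.numpy operations. A minimal sketch, using a hypothetical helper named scatter_values, that runs put under jit and checks the result against the .at form:

```python
import jax
import jax.numpy as jnp

@jax.jit
def scatter_values(x, indices, values):
    # Hypothetical helper: a functional update, safe to trace under jit.
    return jnp.put(x, indices, values, inplace=False)

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 4])
values = jnp.array([10, 20, 30])

out_put = scatter_values(x, indices, values)
out_at = x.at[indices].set(values)
print(out_put)                           # [10  0 20  0 30]
print(jnp.array_equal(out_put, out_at))  # True
```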

There are two modes for handling out-of-bound indices. By default they are clipped:

>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)
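The clipping rule can be reproduced by hand: each index is clamped to the range [0, x.size - 1], so index 6 becomes 4. A minimal sketch of that arithmetic with jnp.clip and the .at syntax; it matches jnp.put for these inputs but is not meant as a description of how jnp.put is implemented internally.

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 6])
values = jnp.array([10, 20, 30])

# Clamp every index into the valid range [0, x.size - 1]; 6 -> 4.
clipped = jnp.clip(indices, 0, x.size - 1)
print(clipped)  # [0 2 4]

manual = x.at[clipped].set(values)
builtin = jnp.put(x, indices, values, inplace=False, mode='clip')
print(manual)                            # [10  0 20  0 30]
print(jnp.array_equal(manual, builtin))  # True
```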

Alternatively, they can be wrapped to the beginning of the array:

>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10, 30, 20,  0,  0], dtype=int32)
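Wrapping amounts to taking each index modulo x.size, so 6 maps to 1 and the value 30 lands at position 1. A minimal sketch of that arithmetic, assuming the .at form reproduces jnp.put for these in-range results:

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 6])
values = jnp.array([10, 20, 30])

# Wrap indices around the array length; 6 % 5 == 1.
wrapped = indices % x.size
print(wrapped)  # [0 2 1]

manual = x.at[wrapped].set(values)
builtin = jnp.put(x, indices, values, inplace=False, mode='wrap')
print(manual)                            # [10 30 20  0  0]
print(jnp.array_equal(manual, builtin))  # True
```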

For N-dimensional inputs, the indices refer to the flattened array:

>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)
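In row-major (C) order, a flat index i into a (3, 5) array corresponds to row i // 5 and column i % 5, so 7 maps to (1, 2) and 14 to (2, 4); jnp.unravel_index computes exactly this mapping. A minimal sketch checking both the per-axis form and a flatten-update-reshape form against jnp.put, assuming the row-major convention the example above uses:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 5), dtype=int)
indices = jnp.array([0, 7, 14])
values = jnp.array([10, 20, 30])

# Convert flat (row-major) indices into one index array per axis.
rows, cols = jnp.unravel_index(indices, x.shape)
print(rows, cols)  # [0 1 2] [0 2 4]

manual = x.at[rows, cols].set(values)
builtin = jnp.put(x, indices, values, inplace=False)
print(jnp.array_equal(manual, builtin))  # True

# Alternative: update the flattened array, then restore the original shape.
flat = x.ravel().at[indices].set(values).reshape(x.shape)
print(jnp.array_equal(flat, builtin))  # True
```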