jax.numpy.ndarray.at (JAX documentation)

abstract property ndarray.at

Helper property for index update functionality.

The at property provides a functionally pure equivalent of in-place array modifications.

In particular:

Alternate syntax              Equivalent in-place expression
----------------------------  ------------------------------
x = x.at[idx].set(y)          x[idx] = y
x = x.at[idx].add(y)          x[idx] += y
x = x.at[idx].subtract(y)     x[idx] -= y
x = x.at[idx].multiply(y)     x[idx] *= y
x = x.at[idx].divide(y)       x[idx] /= y
x = x.at[idx].power(y)        x[idx] **= y
x = x.at[idx].min(y)          x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)          x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)    ufunc.at(x, idx)
x = x.at[idx].get()           x = x[idx]
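One way to check the table is to run several `.at` methods next to the NumPy in-place statements they replace and compare the results. This is a minimal sketch, assuming `numpy` and `jax` are installed; the indices `[1, 3]` and the update values are arbitrary choices for illustration, and only a subset of the table's rows is exercised.

```python
import numpy as np
import jax.numpy as jnp

idx = [1, 3]  # arbitrary indices chosen for illustration
y = 10.0

# NumPy reference results: each statement mutates a fresh copy in place.
base = np.arange(5.0)
np_set = base.copy()
np_set[idx] = y
np_add = base.copy()
np_add[idx] += y
np_mul = base.copy()
np_mul[idx] *= y
np_min = base.copy()
np_min[idx] = np.minimum(np_min[idx], 2.0)
np_neg = base.copy()
np.negative.at(np_neg, idx)  # unary ufunc.at, as in the table's last row

# JAX equivalents: each call returns a new array and leaves x untouched.
x = jnp.arange(5.0)
jidx = jnp.array(idx)
jax_set = x.at[jidx].set(y)
jax_add = x.at[jidx].add(y)
jax_mul = x.at[jidx].multiply(y)
jax_min = x.at[jidx].min(2.0)
jax_neg = x.at[jidx].apply(jnp.negative)

# Compare each JAX result with its NumPy counterpart.
pairs = [(np_set, jax_set), (np_add, jax_add), (np_mul, jax_mul),
         (np_min, jax_min), (np_neg, jax_neg)]
for expected, got in pairs:
    np.testing.assert_allclose(np.asarray(got), expected)
```

The NumPy side uses explicit copies so that every row starts from the same array, mirroring the fact that each JAX call starts from the unchanged `x`.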

None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit()-compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
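The sketch below checks the observable semantics described above: the original array is untouched both eagerly and under `jax.jit`. It does not inspect whether the compiled code actually reuses a buffer; that is a compiler detail. The function name `scale_and_mark` is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)

# Eager mode: .at[...].set returns a new array; x itself is never modified.
y = x.at[0].set(1.0)

@jax.jit
def scale_and_mark(v):
    # w is an intermediate created inside the compiled function. Chained
    # w = w.at[...] updates on such a value are the pattern the text says
    # are applied in place under jit.
    w = v * 2.0
    w = w.at[0].add(1.0)
    w = w.at[-1].set(-1.0)
    return w

z = scale_and_mark(x)  # x is still [0., 1., 2., 3.] afterwards
```

Writing updates as `w = w.at[...]...` reassignments keeps the code purely functional while leaving the compiler free to avoid copies of values no one else holds.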

Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates are applied (NumPy would apply only the last update). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
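A small example with a repeated index makes the difference concrete. NumPy's buffered `a[idx] += 1` applies the repeated index once, while NumPy's unbuffered `np.add.at` and JAX's `.at[idx].add` both accumulate every update. The index list `[0, 0, 2]` is an arbitrary choice for illustration; `add` is used because addition is order-independent, so the result does not depend on the nondeterministic ordering mentioned above.

```python
import numpy as np
import jax.numpy as jnp

idx = [0, 0, 2]  # index 0 appears twice

# NumPy buffered fancy-index add: the duplicate collapses, index 0 gets +1 once.
a = np.zeros(3)
a[idx] += 1.0

# np.add.at is NumPy's unbuffered variant and accumulates duplicates.
b = np.zeros(3)
np.add.at(b, idx, 1.0)

# JAX accumulates every update, matching np.add.at rather than a[idx] += 1.
c = jnp.zeros(3).at[jnp.array(idx)].add(1.0)
```

For `set` with duplicate indices, which value wins is not specified, so code should not rely on it.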

By default, JAX assumes that all indices are in-bounds. Alternative out-of-bounds index semantics can be specified via the mode parameter (see below).

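Parameters:

- mode: out-of-bounds indexing mode. The examples below use 'clip' (clamp out-of-bounds indices into the valid range), 'drop' (ignore out-of-bounds indices in updates), and 'fill' (for get(), return fill_value at out-of-bounds positions). When mode is omitted, updates drop out-of-bounds indices and get() clips them.
- fill_value: used by get() with mode='fill'; the value returned at out-of-bounds positions. If omitted, a floating-point array is filled with NaN.
- wrap_negative_indices: if True (the default), negative indices count from the end of the array; if False, they are treated as out of bounds.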

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)

By default, out-of-bounds indices are ignored in updates, but this behavior can be controlled with the mode parameter:

>>> x.at[10].add(10)  # dropped
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')  # clipped
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)
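The mode parameter matters most for batched index arrays, where a fixed shape (as needed under jit) is often obtained by padding. A minimal sketch of that pattern follows; the sentinel-padding scheme and the names `targets`, `values`, and `sentinel` are illustrative assumptions, not part of the API.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)

# Hypothetical setup: a fixed-length batch of update targets in which unused
# slots are padded with an out-of-range sentinel index (here len(x) == 5).
sentinel = x.shape[0]
targets = jnp.array([1, 3, sentinel, sentinel])
values = jnp.array([10.0, 30.0, 99.0, 99.0])

# mode='drop': updates at out-of-bounds indices are discarded, so the
# padded slots have no effect on the result.
dropped = x.at[targets].add(values, mode='drop')

# mode='clip': the sentinel is clamped to the last valid index (4), so both
# padded values are accumulated there, which defeats the purpose of padding.
clipped = x.at[targets].add(values, mode='clip')
```

With 'drop', the padded entries are inert; with 'clip', they silently corrupt the last element.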

For get(), out-of-bound indices are clipped by default:

>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)
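The same choice applies to gathers with an index array: mode='fill' marks out-of-bounds slots with a recognizable value, whereas clipping returns a real element that is easy to mistake for valid data. A minimal sketch, assuming a hypothetical padded lookup in which index 7 marks an empty slot:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)

# Hypothetical padded lookup: index 7 is out of bounds for a length-5 array.
idx = jnp.array([0, 4, 7])

# mode='fill' returns fill_value wherever the index is out of bounds.
filled = x.at[idx].get(mode='fill', fill_value=-1.0)

# mode='clip' returns x[4] for the padded slot instead.
clipped = x.at[idx].get(mode='clip')

# A validity mask computed from the same indices, for downstream masking.
valid = idx < x.shape[0]
```

Pairing the gathered values with an explicit mask avoids relying on the fill value never occurring in real data.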

Negative indices count from the end of the array, but this behavior can be disabled by setting wrap_negative_indices=False:

>>> x.at[-1].set(99)
Array([ 0.,  1.,  2.,  3., 99.], dtype=float32)
>>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop')  # dropped!
Array([0., 1., 2., 3., 4.], dtype=float32)