jax.ops module

The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
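A minimal migration sketch, assuming a JAX release in which the .at property is available. The removed calls appear only in comments (their exact old spelling is approximate), and variable names such as base and y_set are illustrative. Because JAX arrays are immutable, each .at[...] method returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
base = jnp.arange(5.0)  # values 0.0, 1.0, 2.0, 3.0, 4.0

# Previously roughly: jax.ops.index_update(x, jax.ops.index[1], 10.0)
y_set = x.at[1].set(10.0)

# Previously roughly: jax.ops.index_add(x, jax.ops.index[2:4], 1.0)
y_add = x.at[2:4].add(1.0)

# Counterparts of the other removed helpers (index_mul, index_min, index_max)
y_mul = base.at[1].multiply(3.0)  # multiply the element at index 1 by 3.0
y_min = base.at[4].min(2.0)       # elementwise minimum with 2.0 at index 4
y_max = base.at[0].max(7.0)       # elementwise maximum with 7.0 at index 0

# With repeated indices, .add applies every contribution (index 0 gets 1.0 twice)
y_dup = x.at[jnp.array([0, 0, 3])].add(1.0)

# x itself is never modified; each call above produced a new array
print(x, y_set, y_add, y_mul, y_min, y_max, y_dup)
```

Unlike NumPy's in-place x[idx] += y, repeated indices are all applied by .add; for .set with conflicting indices, the order in which updates land is not something to rely on.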

Segment reduction operators

segment_max(data, segment_ids[, ...]): Computes the maximum within segments of an array.
segment_min(data, segment_ids[, ...]): Computes the minimum within segments of an array.
segment_prod(data, segment_ids[, ...]): Computes the product within segments of an array.
segment_sum(data, segment_ids[, ...]): Computes the sum within segments of an array.
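The four reductions share one calling pattern: each element of data is assigned to the segment named by the matching entry of segment_ids, and the operator combines the elements within each segment. The sketch below assumes a current JAX release and passes the num_segments and indices_are_sorted keyword arguments, which belong to the parameters elided as [, ...] above. per_segment_mean is a hypothetical helper for illustration, not part of jax.ops.

```python
import jax
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
# Three segments: {1, 2} -> 0, {3, 4, 5} -> 1, {6} -> 2
segment_ids = jnp.array([0, 0, 1, 1, 1, 2])

sums = ops.segment_sum(data, segment_ids, num_segments=3)
prods = ops.segment_prod(data, segment_ids, num_segments=3)
maxes = ops.segment_max(data, segment_ids, num_segments=3)
mins = ops.segment_min(data, segment_ids, num_segments=3)

# num_segments fixes the output length; segment 3 has no elements,
# so its sum is 0
padded = ops.segment_sum(data, segment_ids, num_segments=4)

# Hypothetical helper: num_segments is a Python int, so the output shape
# is static and the function can be jit-compiled. The ids are sorted,
# so indices_are_sorted=True is a valid hint.
@jax.jit
def per_segment_mean(data, segment_ids):
    totals = ops.segment_sum(data, segment_ids, num_segments=3,
                             indices_are_sorted=True)
    counts = ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments=3,
                             indices_are_sorted=True)
    return totals / counts

means = per_segment_mean(data, segment_ids)
print(sums, prods, maxes, mins, padded, means)
```

If num_segments is omitted, it is inferred as max(segment_ids) + 1, which needs concrete values; inside a jit-compiled function, pass a static num_segments instead.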