jax.ops module (JAX documentation)
The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
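The sketch below is a rough migration guide that maps the removed update functions onto the `.at` property. The "Old (removed)" comments show the former `jax.ops.index_update(x, jax.ops.index[...], y)` style only for orientation, since those functions no longer exist. The `.at[...].set/add/multiply/max` calls are the supported replacements.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Old (removed): jax.ops.index_update(x, jax.ops.index[2], 10.0)
y = x.at[2].set(10.0)

# Old (removed): jax.ops.index_add(x, jax.ops.index[1:3], 1.0)
z = x.at[1:3].add(1.0)

# Old (removed): jax.ops.index_mul / index_max, chained here for brevity
w = x.at[0].set(4.0).at[0].multiply(0.5).at[4].max(7.0)

# Duplicate indices accumulate under .add, like a scatter-add
d = x.at[jnp.array([0, 0, 3])].add(1.0)

# JAX arrays are immutable: each .at[...] call returns a new array,
# and x itself is left unchanged
print(x, y, z, w, d)
```

Outside of `jax.jit`, each `.at[...]` update produces a new copy of the array. Inside a jitted function, XLA can often perform the update in place.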
Segment reduction operators
| Function | Description |
|---|---|
| segment_max(data, segment_ids[, ...]) | Computes the maximum within segments of an array. |
| segment_min(data, segment_ids[, ...]) | Computes the minimum within segments of an array. |
| segment_prod(data, segment_ids[, ...]) | Computes the product within segments of an array. |
| segment_sum(data, segment_ids[, ...]) | Computes the sum within segments of an array. |
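The following is a small sketch of the four reductions on a toy array. Here `segment_ids[i]` names the segment that `data[i]` belongs to, and `num_segments` fixes the length of the output. Under `jax.jit` the output shape must be known at trace time, so `num_segments` is passed as a Python int. The `per_segment_mean` helper is a hypothetical example built from two `segment_sum` calls. It is not part of `jax.ops`.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
# data[i] belongs to segment segment_ids[i]; ids here are sorted and dense
segment_ids = jnp.array([0, 0, 1, 1, 1, 2])

sums = jax.ops.segment_sum(data, segment_ids, num_segments=3)    # [1+2, 3+4+5, 6]
maxs = jax.ops.segment_max(data, segment_ids, num_segments=3)    # [2, 5, 6]
mins = jax.ops.segment_min(data, segment_ids, num_segments=3)    # [1, 3, 6]
prods = jax.ops.segment_prod(data, segment_ids, num_segments=3)  # [1*2, 3*4*5, 6]


@jax.jit
def per_segment_mean(data, segment_ids):
    # Hypothetical helper: mean = per-segment sum / per-segment count.
    # num_segments is a Python int, so the output shape is static under jit.
    # indices_are_sorted=True is only valid because segment_ids is sorted.
    totals = jax.ops.segment_sum(data, segment_ids, num_segments=3,
                                 indices_are_sorted=True)
    counts = jax.ops.segment_sum(jnp.ones_like(data), segment_ids,
                                 num_segments=3, indices_are_sorted=True)
    return totals / counts


means = per_segment_mean(data, segment_ids)
print(sums, maxs, mins, prods, means)
```

If `num_segments` is omitted, it is inferred from the values in `segment_ids`. This requires concrete ids, so passing it explicitly is the safer choice inside jitted code. `indices_are_sorted=True` is a promise to the compiler rather than a check, so set it only when the ids really are sorted.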