jax.ops module
The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Use the jax.numpy.ndarray.at property on JAX arrays instead.
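A minimal migration sketch, assuming a recent JAX release. Each removed helper corresponds to a method on the .at property: index_update becomes .at[...].set, index_add becomes .at[...].add, index_mul becomes .at[...].multiply, and index_min/index_max become .at[...].min/.at[...].max. The arrays and values below are illustrative only. Because JAX arrays are immutable, each call returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Replaces jax.ops.index_update(x, jax.ops.index[1], 10.0)
y = x.at[1].set(10.0)          # y is [0, 10, 0, 0, 0]

# Replaces jax.ops.index_add(x, jax.ops.index[2:4], 1.0)
z = x.at[2:4].add(1.0)         # z is [0, 0, 1, 1, 0]

# Replaces jax.ops.index_max: elementwise maximum at the given index
m = jnp.arange(5.0).at[0].max(3.0)   # m is [3, 1, 2, 3, 4]

# With repeated indices, .add accumulates every contribution
d = jnp.zeros(3).at[jnp.array([0, 0, 2])].add(1.0)  # d is [2, 0, 1]

# x itself is unchanged: .at returns a new array
print(x, y, z, m, d)
```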
Segment reduction operators
| Function | Description |
|---|---|
| segment_max(data, segment_ids[, ...]) | Computes the maximum within segments of an array. |
| segment_min(data, segment_ids[, ...]) | Computes the minimum within segments of an array. |
| segment_prod(data, segment_ids[, ...]) | Computes the product within segments of an array. |
| segment_sum(data, segment_ids[, ...]) | Computes the sum within segments of an array. |
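As a rough illustration, the sketch below applies all four segment reductions to the same made-up data. Each entry of segment_ids names the output segment that the matching element of data contributes to. It passes the num_segments keyword explicitly. The output shape depends on that value, so under jax.jit it needs to be a concrete Python int rather than something derived from traced values.

```python
import jax
import jax.numpy as jnp
from jax import ops

# Illustrative inputs: elements 0-1 go to segment 0, 2-3 to segment 1, 4 to segment 2
data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 1, 2])

sums = ops.segment_sum(data, segment_ids, num_segments=3)    # [3, 7, 5]
maxes = ops.segment_max(data, segment_ids, num_segments=3)   # [2, 4, 5]
mins = ops.segment_min(data, segment_ids, num_segments=3)    # [1, 3, 5]
prods = ops.segment_prod(data, segment_ids, num_segments=3)  # [2, 12, 5]

# Under jit, num_segments is closed over as a static Python int,
# so the output shape is known at trace time.
jitted_sum = jax.jit(lambda d, s: ops.segment_sum(d, s, num_segments=3))
jit_sums = jitted_sum(data, segment_ids)

print(sums, maxes, mins, prods, jit_sums)
```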