jax.ops.segment_sum — JAX documentation

jax.ops.segment_sum

jax.ops.segment_sum(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)

Computes the sum within segments of an array.

Similar to TensorFlow’s tf.math.segment_sum.
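To make the semantics concrete, here is a minimal pure-Python sketch of what segment_sum computes: each row data[i] is added into output row segment_ids[i], and, mirroring the default behavior, ids outside [0, num_segments) contribute nothing. The helper reference_segment_sum is hypothetical and exists only for comparison; it is not part of JAX. Note that the ids in the example are unsorted: TensorFlow separates tf.math.segment_sum (sorted ids) from tf.math.unsorted_segment_sum, whereas jax.ops.segment_sum accepts either, with indices_are_sorted=True available as a promise when the ids are known to be sorted.

```python
import numpy as np
import jax.numpy as jnp
from jax.ops import segment_sum


def reference_segment_sum(data, segment_ids, num_segments):
    """Hypothetical loop-based reference: out[segment_ids[i]] += data[i]."""
    data = np.asarray(data)
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    for row, seg in zip(data, np.asarray(segment_ids)):
        # Skip ids outside [0, num_segments), as segment_sum does by default.
        if 0 <= seg < num_segments:
            out[seg] += row
    return out


data = jnp.array([5.0, 1.0, 2.0, 7.0, 3.0])
segment_ids = jnp.array([2, 0, 2, 1, 0])  # unsorted, with repeats

expected = reference_segment_sum(data, segment_ids, num_segments=3)
actual = segment_sum(data, segment_ids, num_segments=3)
print(np.allclose(expected, actual))  # True

# The same (value, id) pairs ordered by id; sortedness can then be promised.
sorted_actual = segment_sum(jnp.array([1.0, 3.0, 7.0, 5.0, 2.0]),
                            jnp.array([0, 0, 1, 2, 2]),
                            num_segments=3, indices_are_sorted=True)
```

Passing indices_are_sorted=True with ids that are not actually sorted is a broken promise, so the flag is best left at its default unless sortedness is guaranteed.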

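Parameters:

data: an array with the values to be summed.
segment_ids: an integer array indicating, for each entry of data along its leading axis, the segment it is summed into. Values can be repeated and need not be sorted.
num_segments: optional nonnegative int giving the number of segments. Defaults to the minimum number of segments that supports all ids, max(segment_ids) + 1. Because it determines the output shape, a static value must be provided to use segment_sum in a JIT-compiled function.
indices_are_sorted: whether segment_ids is known to be sorted.
unique_indices: whether segment_ids is known to be free of duplicates.
bucket_size: size of the buckets to group indices into; the sum is computed on each bucket separately to improve the numerical stability of the addition. The default, None, means no bucketing.
mode: a jax.lax.GatherScatterMode value describing how out-of-bound indices are handled. By default, ids outside the range [0, num_segments) are dropped and do not contribute to the sum.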
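The sketch below illustrates how num_segments and mode interact with out-of-range ids, assuming the default behavior described for this function: an omitted num_segments is inferred as max(segment_ids) + 1, ids outside [0, num_segments) are dropped by default, and jax.lax.GatherScatterMode.CLIP instead clamps them into range. The values are chosen so every sum is exact in float32.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 2, 2, 5])

# num_segments inferred as max(segment_ids) + 1 = 6; segments 1, 3 and 4 stay zero.
inferred = segment_sum(data, segment_ids)

# With num_segments=4, id 5 is out of range and its value is dropped by default.
dropped = segment_sum(data, segment_ids, num_segments=4)

# CLIP mode clamps the out-of-range id 5 to the last segment (3) instead.
clipped = segment_sum(data, segment_ids, num_segments=4,
                      mode=jax.lax.GatherScatterMode.CLIP)
```

Inferring num_segments requires concrete segment_ids, which is one reason an explicit value is needed under jit.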

Returns:

An array with shape (num_segments,) + data.shape[1:] representing the segment sums.
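The shape rule (num_segments,) + data.shape[1:] means the reduction happens only along the leading axis of data; any trailing axes are carried through unchanged. A small sketch with 2-D data, where each row is a 3-element feature vector:

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# Four rows of 3 features; rows 0-1 go to segment 0, rows 2-3 to segment 1.
data = jnp.arange(12.0).reshape(4, 3)
segment_ids = jnp.array([0, 0, 1, 1])

sums = segment_sum(data, segment_ids, num_segments=2)
print(sums.shape)  # (2, 3)
```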

Return type:

Array

Examples

Simple 1D segment sum:

```python
>>> import jax.numpy as jnp
>>> from jax.ops import segment_sum
>>> data = jnp.arange(5)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2])
>>> segment_sum(data, segment_ids)
Array([1, 5, 4], dtype=int32)
```
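A common pattern built on top of this, sketched here as an assumption rather than a JAX API, is a per-segment mean from two segment_sum calls: one over the values and one over ones to count the members of each segment. The helper segment_mean is hypothetical; empty segments get a count of 1 to avoid dividing by zero, so their mean comes out as 0.

```python
import jax.numpy as jnp
from jax.ops import segment_sum


def segment_mean(data, segment_ids, num_segments):
    """Hypothetical helper: per-segment mean built from two segment_sum calls."""
    totals = segment_sum(data, segment_ids, num_segments=num_segments)
    counts = segment_sum(jnp.ones_like(data), segment_ids, num_segments=num_segments)
    # Guard empty segments (count 0) against division by zero.
    return totals / jnp.maximum(counts, 1)


data = jnp.arange(5.0)
segment_ids = jnp.array([0, 0, 1, 1, 2])
means = segment_mean(data, segment_ids, num_segments=4)  # segment 3 is empty
```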

Using JIT requires a static num_segments, since it determines the output shape:

```python
>>> from jax import jit
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
Array([1, 5, 4], dtype=int32)
```
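When segment_sum is called inside a larger jitted function, one option (sketched here; the wrapper name pooled_sum is hypothetical) is to mark num_segments as static on the enclosing function via static_argnames. Each distinct num_segments value then triggers a separate compilation, because it changes the output shape.

```python
import functools

import jax
import jax.numpy as jnp
from jax.ops import segment_sum


# num_segments fixes the output shape, so it is marked static (a compile-time constant).
@functools.partial(jax.jit, static_argnames="num_segments")
def pooled_sum(data, segment_ids, num_segments):
    """Hypothetical jitted wrapper around segment_sum."""
    return segment_sum(data, segment_ids, num_segments=num_segments)


data = jnp.arange(5.0)
segment_ids = jnp.array([0, 0, 1, 1, 2])

out3 = pooled_sum(data, segment_ids, num_segments=3)  # compiled for num_segments=3
out5 = pooled_sum(data, segment_ids, num_segments=5)  # new static value, so recompiled
print(out3.shape, out5.shape)  # (3,) (5,)
```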