jax.ops.segment_min — JAX documentation

jax.ops.segment_min

jax.ops.segment_min(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)

Computes the minimum within segments of an array.

Similar to TensorFlow’s segment_min.
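To make the semantics concrete, the sketch below compares segment_min against a naive reference that loops over segments, masks out rows with a different id using jnp.where, and reduces the rest with jnp.min. naive_segment_min is a hypothetical helper written only for illustration; it is not part of JAX, does not reflect how JAX computes the result internally, and assumes 1D floating-point data so that +inf can act as a neutral fill value.

```python
import jax.numpy as jnp
import numpy as np
from jax.ops import segment_min


def naive_segment_min(data, segment_ids, num_segments):
    """Hypothetical reference helper, not part of JAX.

    For each segment k, rows whose id differs from k are replaced by +inf
    so they can never be the minimum; jnp.min then reduces what remains.
    Assumes 1D floating-point data.
    """
    mins = []
    for k in range(num_segments):
        in_segment = segment_ids == k
        mins.append(jnp.min(jnp.where(in_segment, data, jnp.inf)))
    return jnp.stack(mins)


data = jnp.array([5.0, 1.0, 7.0, 3.0, 2.0, 9.0])
# Segment ids may be repeated and do not have to be sorted.
segment_ids = jnp.array([2, 0, 1, 0, 2, 1])

fast = segment_min(data, segment_ids, num_segments=3)
slow = naive_segment_min(data, segment_ids, num_segments=3)
print(fast)  # [1. 7. 2.]
```

The loop is O(num_segments × len(data)), which is why it only serves as a readable specification to test against.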

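Parameters:

data (Array): an array with the values to be reduced.

segment_ids (Array): an array with integer dtype that indicates the segments of data (along its leading axis) to be reduced. Values can be repeated and need not be sorted. Values outside the range [0, num_segments) are dropped and do not contribute to the result.

num_segments (int | None): optional non-negative int giving the number of segments. The default is the minimum number of segments that covers all indices in segment_ids, calculated as max(segment_ids) + 1. Since num_segments determines the size of the output, a static value must be provided to use segment_min in a JIT-compiled function.

indices_are_sorted (bool): whether segment_ids is known to be sorted.

unique_indices (bool): whether segment_ids is known to be free of duplicates.

bucket_size (int | None): size of the buckets into which indices are grouped. segment_min is performed on each bucket separately. The default None means no bucketing.

mode (lax.GatherScatterMode | None): a jax.lax.GatherScatterMode value describing how out-of-bounds indices should be handled. By default, values outside the range [0, num_segments) are dropped and do not contribute to the result.
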
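The sketch below illustrates how num_segments interacts with segment_ids, assuming the default mode=None: when num_segments is omitted it is inferred as max(segment_ids) + 1, and when it is given explicitly, rows whose id falls outside [0, num_segments) are dropped. It also shows what an empty segment contains. In current JAX releases this appears to be the identity element of min for the dtype (the largest int32 here), but treat that fill value as an assumption and check it against your installed version.

```python
import jax.numpy as jnp
import numpy as np
from jax.ops import segment_min

data = jnp.array([4, 8, 1, 6], dtype=jnp.int32)
segment_ids = jnp.array([0, 0, 2, 2])  # no row is assigned to segment 1

# num_segments omitted: inferred as max(segment_ids) + 1 == 3.
inferred = segment_min(data, segment_ids)

# num_segments=2: the rows with id 2 are out of range and are dropped.
truncated = segment_min(data, segment_ids, num_segments=2)

# Assumption: an empty segment holds min's identity, the int32 maximum.
INT32_MAX = np.iinfo(np.int32).max
print(inferred.shape, truncated.shape)  # (3,) (2,)
```

If empty segments are possible in your data, it is usually worth masking or replacing that fill value explicitly rather than relying on it downstream.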
Returns:

An array with shape (num_segments,) + data.shape[1:] representing the segment minimums.

Return type:

Array
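For multidimensional data the reduction runs only over the leading axis: each segment's result is the elementwise minimum over the rows assigned to it, so the output has shape (num_segments,) + data.shape[1:]. The small matrix below is made-up illustration data.

```python
import jax.numpy as jnp
import numpy as np
from jax.ops import segment_min

# 4 rows x 3 columns; rows 0 and 1 form segment 0, rows 2 and 3 form segment 1.
data = jnp.array([[3, 9, 4],
                  [5, 2, 8],
                  [7, 1, 6],
                  [0, 4, 5]])
segment_ids = jnp.array([0, 0, 1, 1])

out = segment_min(data, segment_ids, num_segments=2)
print(out.shape)  # (2, 3), i.e. (num_segments,) + data.shape[1:]
print(out)        # column-wise minimum of each segment's rows
```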

Examples

Simple 1D segment min:

>>> import jax.numpy as jnp
>>> from jax.ops import segment_min
>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_min(data, segment_ids)
Array([0, 2, 4], dtype=int32)

Using JIT requires a static num_segments, because it determines the shape of the output:

>>> from jax import jit
>>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
Array([0, 2, 4], dtype=int32)
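Marking num_segments static by position is one option. As a sketch of two alternatives, it can also be marked static by name with static_argnames, or bound as a plain Python int with functools.partial before jitting so the compiled function only receives arrays. The names jit_by_name and min_per_segment are hypothetical wrappers chosen for this illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.ops import segment_min

data = jnp.arange(6)
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])

# Option 1: mark the keyword argument num_segments as static by name.
jit_by_name = jax.jit(segment_min, static_argnames="num_segments")
by_name = jit_by_name(data, segment_ids, num_segments=3)

# Option 2: bind num_segments with functools.partial, so it is a Python int
# at trace time and the jitted function takes only array arguments.
min_per_segment = jax.jit(partial(segment_min, num_segments=3))
by_partial = min_per_segment(data, segment_ids)
```

With static_argnames, each distinct num_segments value triggers a separate compilation; the partial form fixes a single value up front.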