jax.scipy.stats.mode — JAX documentation

jax.scipy.stats.mode

jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)

Compute the mode (most common value) along an axis of an array.

JAX implementation of scipy.stats.mode().
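
To make "most common value" concrete, here is a minimal reference sketch for the 1-D case. The helper mode_1d_reference is hypothetical (it is not part of JAX and is not how the library implements mode); it counts each distinct value with jnp.unique and picks the most frequent one, then cross-checks the result against jax.scipy.stats.mode on the array used in the examples below.

```python
import jax.numpy as jnp
from jax.scipy.stats import mode


def mode_1d_reference(x):
    """Hypothetical reference helper (not part of JAX): mode of a 1-D array.

    Calls jnp.unique without `size=`, so it works eagerly but not under jax.jit.
    """
    # Sorted distinct values and how many times each one occurs.
    vals, counts = jnp.unique(x, return_counts=True)
    # argmax returns the first maximal index; because vals is sorted,
    # a tie resolves to the smallest value.
    i = jnp.argmax(counts)
    return vals[i], counts[i]


x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
ref_mode, ref_count = mode_1d_reference(x)
jax_mode, jax_count = mode(x)
print(ref_mode, ref_count)  # 4 3
print(jax_mode, jax_count)  # 4 3
```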

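Parameters:

a (ArrayLike) – input array.
axis (int) – default 0. Axis along which to compute the mode.
nan_policy (str) – default 'propagate'. JAX only supports 'propagate'.
keepdims (bool) – default False. If True, the reduced axis is left in the result with size 1.
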
Returns:

A tuple of arrays, (mode, count). mode is the array of modal values, and count is the number of times each modal value appears along the reduced axis of the input array.

Return type:

ModeResult
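
The sketch below assumes ModeResult is a named tuple with fields mode and count, which matches the (mode, count) description above. Under that assumption the result can be read by field name or unpacked by position, and, being a pytree, it can be returned from a jax.jit-compiled function; most_common is a hypothetical wrapper used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import mode


@jax.jit
def most_common(x):
    # Hypothetical wrapper: the named-tuple result is a pytree,
    # so it passes through jax.jit unchanged.
    return mode(x)


res = most_common(jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]))
print(res.mode, res.count)  # 4 3

# Positional unpacking works as well, since the result is a tuple.
m, c = res
```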

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))
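
When several values are tied for most common, SciPy documents that the smallest of them is returned. The sketch below assumes JAX follows the same convention; the input is chosen so that 3 occurs first but 2 is smaller, which distinguishes "smallest" from "first seen".

```python
import jax.numpy as jnp
from jax.scipy.stats import mode

# 2 and 3 both occur twice and 1 occurs once; 3 appears before 2.
x = jnp.array([3, 1, 2, 3, 2])
m, c = mode(x)
print(m, c)  # 2 2
```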

For multidimensional arrays, jax.scipy.stats.mode computes the mode and the corresponding count along axis=0 by default, that is, separately for each column of a 2-D array:

>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))
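
One way to see what "along axis=0" means is to apply the 1-D mode to each column independently. The sketch below does that with jax.vmap (in_axes=1 hands each column x1[:, j] to a separate call) and checks that it agrees with the axis=0 result; it is an illustration of the semantics, not a description of how mode is implemented.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import mode

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

# Map the 1-D mode over columns: each call receives one column of x1.
col_mode, col_count = jax.vmap(mode, in_axes=1)(x1)
ax_mode, ax_count = mode(x1, axis=0)

print(col_mode)   # [1 2 1 3 1 1]
print(col_count)  # [2 2 1 2 2 1]
```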

With axis=1, the mode and count are computed along axis 1, that is, separately for each row:

>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))
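
SciPy also accepts axis=None, which computes a single mode over the whole flattened array. The sketch below assumes jax.scipy.stats.mode mirrors that behavior; for x1, the value 1 occurs 7 times out of 18, more than 2 (6 times) or 3 (5 times).

```python
import jax.numpy as jnp
from jax.scipy.stats import mode

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

# Assumption: axis=None flattens the input first, as in SciPy.
m, c = mode(x1, axis=None)
print(m, c)  # 1 7

# Same result as explicitly flattening and using the default axis=0.
m_flat, c_flat = mode(x1.ravel())
```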

By default, jax.scipy.stats.mode removes the reduced axis from the result. To keep the same number of dimensions as the input array, with the reduced axis left at size 1, set keepdims=True.

>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
       [3],
       [2]], dtype=int32), Array([[3],
       [3],
       [3]], dtype=int32))
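
A practical reason to use keepdims=True is that the size-1 axis broadcasts against the input. The sketch below builds a boolean mask of each row's modal entries and confirms that summing the mask along the row reproduces count.

```python
import jax.numpy as jnp
from jax.scipy.stats import mode

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

m, c = mode(x1, axis=1, keepdims=True)  # both have shape (3, 1)

# m broadcasts across the 6 columns, marking where each row equals its mode.
is_modal = x1 == m                      # shape (3, 6)
per_row = is_modal.sum(axis=1, keepdims=True)
print(per_row.ravel())  # [3 3 3]
```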