jax.scipy.stats.mode — JAX documentation
jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)
Compute the mode (most common value) along an axis of an array.
JAX implementation of scipy.stats.mode().
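Here is a minimal reference sketch of the 1-D case that makes the definition concrete. mode_1d is a hypothetical helper, not part of JAX. It counts each distinct value with jnp.unique(..., return_counts=True) and picks the most frequent one. jnp.unique returns its values sorted, and jnp.argmax returns the first maximal index, so ties go to the smallest value in this sketch. Treat that as a property of the sketch, not a documented guarantee of jax.scipy.stats.mode.

```python
import jax.numpy as jnp
import jax.scipy.stats

def mode_1d(x):
    """Hypothetical reference helper: mode of a 1-D array and its count.

    Runs eagerly only: jnp.unique without a static `size` has a
    data-dependent output shape, so this cannot be jit-compiled.
    """
    values, counts = jnp.unique(x, return_counts=True)  # values sorted ascending
    i = jnp.argmax(counts)  # first index of the largest count: ties pick the smallest value
    return values[i], counts[i]

x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
print(mode_1d(x))               # hand-rolled reference
print(jax.scipy.stats.mode(x))  # library result, for comparison
```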
Parameters:
- a (ArrayLike) – input array.
- axis (int | None) – int, default=0. Axis along which to compute the mode. If None, the mode is computed over the flattened array.
- nan_policy (str) – str, default="propagate". JAX only supports "propagate".
- keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
Returns:
A tuple of arrays, (mode, count). mode is the array of modal values, and count is the number of times each modal value occurs along the reduced axis.
Return type:
ModeResult
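The returned ModeResult can also be read through named fields. The sketch below assumes, as in scipy.stats.mode, that ModeResult is a named tuple with mode and count fields, and that axis=None reduces over the flattened input. With x1 flattened, the value 1 occurs 7 times out of 18 elements.

```python
import jax.numpy as jnp
import jax.scipy.stats

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

# axis=None: reduce over all 18 elements at once.
res = jax.scipy.stats.mode(x1, axis=None)

# Fields can be read by name as well as unpacked like a tuple.
print(res.mode, res.count)
```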
Examples
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))
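JAX only supports nan_policy="propagate". For a 1-D float input, one way to ignore NaNs is to drop them before calling mode. The helper mode_ignoring_nan below is hypothetical, not a JAX API. It relies on boolean-mask indexing, which produces a data-dependent shape and therefore only works outside jax.jit.

```python
import jax.numpy as jnp
import jax.scipy.stats

def mode_ignoring_nan(x):
    """Hypothetical helper: mode of a 1-D array after dropping NaNs.

    Boolean-mask indexing gives a data-dependent shape, so this runs
    eagerly and cannot be wrapped in jax.jit.
    """
    x = jnp.asarray(x)
    kept = x[~jnp.isnan(x)]  # remove NaN entries before counting
    return jax.scipy.stats.mode(kept)

y = jnp.array([1.0, 2.0, 2.0, jnp.nan, jnp.nan, jnp.nan])
print(mode_ignoring_nan(y))  # NaNs are not counted, so 2.0 is the mode
```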
For multi-dimensional arrays, jax.scipy.stats.mode computes the mode and the corresponding count along axis=0 by default:
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))
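Reducing along axis=0 treats each column x1[:, j] as its own 1-D input. The sketch below checks that reading by recomputing the result one column at a time. Column 2 is [1, 3, 2], where every value appears once. The reported mode there is 1 with count 1, which in this example is the smallest of the tied values.

```python
import jax.numpy as jnp
import jax.scipy.stats

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

# One call that reduces the whole array along axis 0.
mode, count = jax.scipy.stats.mode(x1, axis=0)

# The same result, computed column by column from 1-D slices.
per_col = [jax.scipy.stats.mode(x1[:, j]) for j in range(x1.shape[1])]
col_modes = jnp.array([m for m, _ in per_col])
col_counts = jnp.array([c for _, c in per_col])

print(jnp.array_equal(mode, col_modes), jnp.array_equal(count, col_counts))
```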
With axis=1, mode and count are computed along axis 1, that is, across each row:
>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))
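The same rule extends to higher-dimensional input: the axis passed as axis is removed from the result shape. The sketch uses an assumed (2, 3, 4) integer array. Reducing over axis=1 leaves shape (2, 4), and each entry mode[i, k] is the mode of the 1-D slice x3[i, :, k].

```python
import jax.numpy as jnp
import jax.scipy.stats

# Illustrative (2, 3, 4) integer array; integer division creates repeated values.
x3 = jnp.arange(24).reshape(2, 3, 4) // 5

mode, count = jax.scipy.stats.mode(x3, axis=1)
print(mode.shape, count.shape)  # axis 1 is dropped: (2, 4) (2, 4)

# Each entry of the result is the mode of one 1-D slice along axis 1.
m, c = jax.scipy.stats.mode(x3[0, :, 1])
print(int(m) == int(mode[0, 1]), int(c) == int(count[0, 1]))
```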
By default, jax.scipy.stats.mode removes the reduced axis from the result. To keep the result's number of dimensions equal to that of the input array, set keepdims=True:
>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
       [3],
       [2]], dtype=int32), Array([[3],
       [3],
       [3]], dtype=int32))
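A practical reason to set keepdims=True is broadcasting. The (3, 1) result lines up with the (3, 6) input, so it can be compared element-wise without reshaping. Without keepdims the mode has shape (3,), which does not broadcast against (3, 6). The sketch below builds a mask of modal positions and checks that summing it along axis 1 reproduces count.

```python
import jax.numpy as jnp
import jax.scipy.stats

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)  # both (3, 1)

# (3, 6) == (3, 1) broadcasts per row: True where an element equals its row's mode.
is_modal = x1 == mode
print(is_modal.astype(int))

# Summing the mask along the reduced axis recovers the counts.
print(jnp.array_equal(is_modal.sum(axis=1, keepdims=True), count))
```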