jax.numpy.percentile

jax.numpy.percentile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, interpolation=Deprecated)

Compute the percentile of the data along the specified axis.

JAX implementation of numpy.percentile().

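Parameters:

a (ArrayLike): N-dimensional input array.
q (ArrayLike): scalar or 1-dimensional array of the percentiles to compute; values should lie between 0 and 100.
axis (int | tuple[int, ...] | None): optional axis or tuple of axes along which to compute the percentile. With the default of None, the percentile is computed over the flattened array.
out (None): not implemented by JAX; raises an error if not None.
overwrite_input (bool): not implemented by JAX; raises an error if not False.
method (str): interpolation method to use. One of "linear", "lower", "higher", "midpoint", or "nearest". Default is "linear".
keepdims (bool): if True, the reduced axes are kept in the result with size 1. Default is False.
interpolation: deprecated alias of the method argument.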

Returns:

An array containing the specified percentiles along the specified axes.

Return type:

Array

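See also

jax.numpy.quantile(): compute the quantile, with q given in the range 0.0 to 1.0.
jax.numpy.nanpercentile(): compute the percentile while ignoring NaN values.
jax.numpy.median(): compute the median (the 50th percentile).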

Examples

Computing the median and quartiles of a 1D array:

>>> x = jnp.array([0, 1, 2, 3, 4, 5, 6])
>>> q = jnp.array([25, 50, 75])
>>> jnp.percentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
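To make the default method='linear' concrete, the following sketch recomputes the 25th percentile of the same array by hand. It assumes the usual linear-interpolation rule, in which the percentile sits at fractional position q/100 * (n - 1) in the sorted data. The names sorted_x, pos, lower, upper, frac, and manual are illustrative only and are not part of the JAX API.

```python
import math

import jax.numpy as jnp

x = jnp.array([0, 1, 2, 3, 4, 5, 6])

# Sort and convert to float so interpolation between neighbours is possible.
sorted_x = jnp.sort(x).astype(jnp.float32)
n = sorted_x.shape[0]

# Fractional position of the 25th percentile in the sorted data.
pos = 25 / 100 * (n - 1)

# Neighbouring indices and the interpolation weight between them.
lower = math.floor(pos)
upper = math.ceil(pos)
frac = pos - lower

# Linear interpolation between the two neighbouring values.
manual = sorted_x[lower] * (1 - frac) + sorted_x[upper] * frac

# Library result for comparison.
library = jnp.percentile(x, 25)
```

Here the position falls halfway between the second and third sorted values, which is why the first entry in the result above is 1.5.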

Computing the same percentiles with nearest rather than linear interpolation:

>>> jnp.percentile(x, q, method='nearest')
Array([1., 3., 4.], dtype=float32)
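The examples above reduce over a 1D array. As a minimal sketch of how the axis and keepdims arguments from the signature affect the result, the snippet below uses a small 2D array; the variable names (m, row_median, and so on) are illustrative only.

```python
import jax.numpy as jnp

# Two rows of three values each.
m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# Median of each row: reduce over axis 1, giving one value per row.
row_median = jnp.percentile(m, 50, axis=1)

# keepdims=True keeps the reduced axis with length 1, which is convenient
# for broadcasting the result back against the input.
row_median_kept = jnp.percentile(m, 50, axis=1, keepdims=True)
centered = m - row_median_kept

# With the default axis=None, the percentile is taken over all elements.
overall_median = jnp.percentile(m, 50)

# When q is an array, its dimension comes first in the output shape.
column_quartiles = jnp.percentile(m, jnp.array([25, 75]), axis=0)
```

In this sketch row_median has shape (2,), row_median_kept has shape (2, 1), and column_quartiles has shape (2, 3): one row per requested percentile, one column per column of m.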