jax.numpy.quantile — JAX documentation

jax.numpy.quantile

jax.numpy.quantile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, interpolation=Deprecated)

Compute the quantile of the data along the specified axis.

JAX implementation of numpy.quantile().
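Because jnp.quantile follows numpy.quantile, one quick sanity check is to compare the two on the same data; like other jax.numpy functions, it can also be traced under jax.jit. The sketch below assumes NumPy is installed alongside JAX, and the sample data and the name quantile_jit are illustrative only. Results agree up to float32 precision, since JAX uses 32-bit floats by default.

```python
import numpy as np
import jax
import jax.numpy as jnp

data = np.array([7.0, 1.0, 4.0, 9.0, 3.0, 8.0])
q = [0.1, 0.5, 0.9]

# Reference values from NumPy, using its default method='linear'.
expected = np.quantile(data, q)

# The same computation in JAX, run eagerly...
eager = jnp.quantile(jnp.asarray(data), jnp.asarray(q))

# ...and under jax.jit. The method string is fixed inside the lambda so that
# only arrays are passed as traced arguments.
quantile_jit = jax.jit(lambda a, qs: jnp.quantile(a, qs, method="linear"))
jitted = quantile_jit(jnp.asarray(data), jnp.asarray(q))

print(expected, eager, jitted)   # all approximately [2.0, 5.5, 8.5]
```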

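Parameters:

- a: N-dimensional array input.
- q: scalar or 1-dimensional array specifying the desired quantiles; q should contain floating-point values between 0.0 and 1.0.
- axis: optional axis or tuple of axes along which to compute the quantile.
- out: not implemented by JAX; will error if not None.
- overwrite_input: not implemented by JAX; will error if not False.
- method: interpolation method to use; one of "linear", "lower", "higher", "midpoint", or "nearest". Default is "linear".
- keepdims: if True, the returned array has the same number of dimensions as the input. Default is False.
- interpolation: deprecated alias of the method argument.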

Returns:

An array containing the specified quantiles along the specified axes.

Return type:

Array
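The shape of the returned array depends on q, axis, and keepdims. A short sketch, using a small (3, 4) array chosen for illustration: a 1-D q contributes a leading axis of quantiles, axis selects which dimension is reduced, and keepdims=True leaves the reduced axis in place with length 1. The jax.vmap line is included only to show that reducing along axis=1 matches applying the 1-D computation row by row.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)   # rows: [0..3], [4..7], [8..11]
q = jnp.array([0.25, 0.75])

# axis=None (default): x is treated as flattened; a 1-D q gives one value per quantile.
flat = jnp.quantile(x, q)
# axis=1: each row is reduced; the quantile axis comes first in the result.
per_row = jnp.quantile(x, q, axis=1)
# keepdims=True: the reduced axis is kept with length 1.
per_row_kept = jnp.quantile(x, q, axis=1, keepdims=True)
print(flat.shape, per_row.shape, per_row_kept.shape)   # (2,) (2, 3) (2, 3, 1)

# A scalar q adds no leading axis, and mapping a 1-D quantile over rows
# with jax.vmap gives the same row medians.
row_medians = jnp.quantile(x, 0.5, axis=1)
vmapped = jax.vmap(lambda row: jnp.quantile(row, 0.5))(x)
print(row_medians, vmapped)   # both approximately [1.5, 5.5, 9.5]
```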

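See also

- jax.numpy.nanquantile: compute the quantile while ignoring NaNs.
- jax.numpy.percentile: compute the percentile (0 to 100).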
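jnp.quantile has two close relatives in jax.numpy: percentile, which takes q on a 0 to 100 scale, and nanquantile, which skips NaN values. A minimal sketch comparing them on illustrative inputs; the NaN result shown for quantile follows NumPy's convention of returning NaN when the reduced data contains a NaN.

```python
import jax.numpy as jnp

x = jnp.arange(10.0)

# percentile takes q on a 0-100 scale; these calls ask for the same quartiles.
by_quantile = jnp.quantile(x, jnp.array([0.25, 0.5, 0.75]))
by_percentile = jnp.percentile(x, jnp.array([25.0, 50.0, 75.0]))
print(by_quantile, by_percentile)   # both approximately [2.25, 4.5, 6.75]

# With a NaN in the data, quantile returns NaN, while nanquantile drops
# the NaN and takes the median of the remaining values [1.0, 3.0].
y = jnp.array([1.0, jnp.nan, 3.0])
with_nan = jnp.quantile(y, 0.5)
skip_nan = jnp.nanquantile(y, 0.5)
print(with_nan, skip_nan)           # nan 2.0
```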

Examples

Computing the median and quartiles of an array, with linear interpolation:

>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> q = jnp.array([0.25, 0.5, 0.75])
>>> jnp.quantile(x, q)
Array([2.25, 4.5 , 6.75], dtype=float32)
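With method='linear', quantile q of n sorted values sits at the fractional index q * (n - 1), and the result interpolates between the two neighbouring values. For x = jnp.arange(10) and q = 0.25 that index is 2.25, giving 2 + 0.25 * (3 - 2) = 2.25. The helper linear_quantile_1d below is hypothetical (not part of JAX) and only handles 1-D input; it is a sketch of the interpolation rule, not JAX's actual implementation.

```python
import jax.numpy as jnp

def linear_quantile_1d(a, q):
    """Hypothetical helper: linear-method quantile of a 1-D array, for illustration."""
    s = jnp.sort(a)
    n = s.shape[0]
    pos = q * (n - 1)                   # fractional index into the sorted data
    lo = jnp.floor(pos).astype(jnp.int32)
    hi = jnp.minimum(lo + 1, n - 1)     # clamp so q == 1.0 stays in bounds
    frac = pos - lo                     # weight given to the upper neighbour
    return s[lo] * (1 - frac) + s[hi] * frac

x = jnp.arange(10)
q = jnp.array([0.25, 0.5, 0.75])
print(linear_quantile_1d(x, q))         # approximately [2.25, 4.5, 6.75]
print(jnp.quantile(x, q))               # same values from the library function
```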

Computing the quartiles using nearest-value interpolation:

>>> jnp.quantile(x, q, method='nearest')
Array([2., 4., 7.], dtype=float32)
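The interpolation methods differ only in how they resolve a fractional index that falls between two data points. A sketch comparing them on the same data as the examples above, assuming JAX accepts the five method names 'linear', 'lower', 'higher', 'midpoint', and 'nearest' (a subset of the methods NumPy offers). For q = [0.25, 0.5, 0.75] and ten values, the fractional indices are 2.25, 4.5, and 6.75.

```python
import jax.numpy as jnp

x = jnp.arange(10)
q = jnp.array([0.25, 0.5, 0.75])   # fractional indices 2.25, 4.5, 6.75

results = {}
for method in ["linear", "lower", "higher", "midpoint", "nearest"]:
    results[method] = jnp.quantile(x, q, method=method)
    print(f"{method:>8}: {results[method]}")

# Values worked out by hand from the fractional indices:
#   linear:   2.25, 4.5, 6.75  (interpolate between neighbours)
#   lower:    2,    4,   6     (take the lower neighbour)
#   higher:   3,    5,   7     (take the higher neighbour)
#   midpoint: 2.5,  4.5, 6.5   (average the two neighbours)
#   nearest:  2,    4,   7     (take the closer neighbour; the 4.5 tie gives 4)
```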