jax.numpy.quantile (JAX documentation)
jax.numpy.quantile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, interpolation=Deprecated)
Compute the quantile of the data along the specified axis.
JAX implementation of numpy.quantile().
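Because this function mirrors numpy.quantile(), one way to sanity-check it is to run both on the same data and compare. The sketch below assumes NumPy 1.22 or newer, where numpy.quantile() accepts method=. The data values and the `matches` dictionary are illustrative only. It leaves out "nearest", since we have not checked that its tie-breaking at exactly halfway positions agrees with NumPy's.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative data and quantile levels (not from the original page).
data = np.array([7.0, 1.0, 3.0, 9.0, 4.0, 6.0])
levels = np.array([0.1, 0.5, 0.9])

matches = {}
for method in ["linear", "lower", "higher", "midpoint"]:
    expected = np.quantile(data, levels, method=method)  # NumPy reference result
    actual = jnp.quantile(data, levels, method=method)   # JAX result
    # JAX computes in float32 by default, so compare with a tolerance.
    matches[method] = bool(np.allclose(actual, expected))

print(matches)
```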
Parameters:
- a (ArrayLike) – N-dimensional array input.
- q (ArrayLike) – scalar or 1-dimensional array specifying the desired quantiles. q should contain floating-point values between 0.0 and 1.0.
- axis (int | tuple[int, ...] | None) – optional axis or tuple of axes along which to compute the quantile.
- out (None) – not implemented by JAX; will error if not None
- overwrite_input (bool) – not implemented by JAX; will error if not False
- method (str) – specify the interpolation method to use. Options are one of ["linear", "lower", "higher", "midpoint", "nearest"]. Default is "linear".
- keepdims (bool) – if True, then the returned array will have the same number of dimensions as the input. Default is False.
- interpolation (DeprecatedArg) – deprecated; use method instead.
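The sketch below illustrates how q, axis and keepdims determine the shape of the result. The 3x4 input is an arbitrary example. The shapes in the comments follow NumPy's convention, which jnp.quantile is meant to match: a 1-D q adds a leading dimension of length len(q), and keepdims=True keeps each reduced axis as a dimension of size one.

```python
import jax.numpy as jnp

# An arbitrary 3x4 input holding the values 0.0 ... 11.0.
a = jnp.arange(12.0).reshape(3, 4)
levels = jnp.array([0.25, 0.75])

# axis=0 reduces over the rows; the 1-D q adds a leading axis of length 2.
per_column = jnp.quantile(a, levels, axis=0)
print(per_column.shape)  # (2, 4)

# keepdims=True keeps the reduced axis as a length-1 dimension.
per_column_kept = jnp.quantile(a, levels, axis=0, keepdims=True)
print(per_column_kept.shape)  # (2, 1, 4)

# A tuple of axes reduces over all of them; a scalar q adds no extra axis.
median_tuple = jnp.quantile(a, 0.5, axis=(0, 1))
# axis=None (the default) reduces over the flattened array.
median_all = jnp.quantile(a, 0.5)
print(median_tuple, median_all)  # both are the median of 0..11, i.e. 5.5
```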
Returns:
An array containing the specified quantiles along the specified axes.
Return type:
Array
See also
- jax.numpy.nanquantile(): compute the quantile while ignoring NaNs
- jax.numpy.percentile(): compute the percentile (0-100)
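A short sketch of how the two related functions line up with quantile(), using an arbitrary float input. Since percentile() expects q on a 0-100 scale, scaling q by 100 should reproduce the quantile() result. nanquantile() should agree with quantile() applied to the same data with the NaN entry removed.

```python
import jax.numpy as jnp

x = jnp.arange(10.0)
levels = jnp.array([0.25, 0.5, 0.75])

# percentile() uses a 0-100 scale where quantile() uses 0-1.
from_quantile = jnp.quantile(x, levels)
from_percentile = jnp.percentile(x, 100 * levels)
print(jnp.allclose(from_quantile, from_percentile))

# Replace one element with NaN; nanquantile() skips it.
x_nan = x.at[3].set(jnp.nan)
x_clean = jnp.delete(x, 3)  # the same data with the NaN position dropped
print(jnp.nanquantile(x_nan, 0.5), jnp.quantile(x_clean, 0.5))
```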
Examples
Computing the median and quartiles of an array, with linear interpolation:
>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> q = jnp.array([0.25, 0.5, 0.75])
>>> jnp.quantile(x, q)
Array([2.25, 4.5 , 6.75], dtype=float32)
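Here is where 2.25, 4.5 and 6.75 come from. Under the "linear" method, the quantile q of n sorted values sits at the fractional position q*(n-1), and the result interpolates between the two neighbouring sorted values. For x = 0..9 and q = 0.25, that position is 2.25, between the values 2 and 3. The helper `linear_quantile_1d` below is hypothetical. It is a minimal sketch for 1-D input under this NumPy-style definition, not JAX's actual implementation.

```python
import jax.numpy as jnp

def linear_quantile_1d(a, q):
    """Hypothetical helper: 'linear' quantile of a 1-D array (illustrative sketch)."""
    a_sorted = jnp.sort(jnp.asarray(a, dtype=jnp.float32))
    q = jnp.asarray(q, dtype=jnp.float32)
    # Fractional 0-based position of each quantile within the sorted data.
    pos = q * (a_sorted.shape[0] - 1)
    low = jnp.floor(pos).astype(jnp.int32)
    high = jnp.ceil(pos).astype(jnp.int32)
    # The fractional part of the position is the weight on the upper neighbour.
    frac = pos - low
    return a_sorted[low] * (1 - frac) + a_sorted[high] * frac

x = jnp.arange(10)
q = jnp.array([0.25, 0.5, 0.75])
print(linear_quantile_1d(x, q))
print(jnp.quantile(x, q))
```

The sort step matters for unordered input; for the already sorted x above it is a no-op.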
Computing the quartiles using nearest-value interpolation:
>>> jnp.quantile(x, q, method='nearest')
Array([2., 4., 7.], dtype=float32)
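Running every supported method on the same x and q makes the differences concrete. This is a short sketch; the loop and the `results` dictionary are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(10)
q = jnp.array([0.25, 0.5, 0.75])

results = {}
for method in ["linear", "lower", "higher", "midpoint", "nearest"]:
    # Each method resolves the fractional positions 2.25, 4.5, 6.75 differently.
    results[method] = jnp.quantile(x, q, method=method)
    print(method, results[method])
```

For this input, "lower" and "higher" pick the sorted neighbours below and above each position (2, 4, 6 and 3, 5, 7). "midpoint" averages those neighbours (2.5, 4.5, 6.5), and "nearest" gives 2, 4, 7 as in the example above. Every method returns float32 here, even though x holds integers.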