jax.numpy.sort — JAX documentation
jax.numpy.sort
jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)
Return a sorted copy of an array.
JAX implementation of numpy.sort().
Parameters:
- a (Array | ndarray | bool | number | int | float | complex) – array to sort
- axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
- stable (bool) – boolean specifying whether a stable sort should be used. Default=True.
- descending (bool) – boolean specifying whether to sort in descending order. Default=False.
- kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.
- order (None) – not supported by JAX
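As a minimal sketch of how the axis and descending parameters described above behave, the snippet below sorts a small 2-D array three ways. The array x and the variable names are illustrative only, and descending is assumed to be available in the installed JAX version, as in the signature shown above.

```python
import jax.numpy as jnp

# Illustrative input; the values are arbitrary.
x = jnp.array([[3, 5, 2],
               [6, 1, 4]])

# axis=None flattens the input before sorting.
flat_sorted = jnp.sort(x, axis=None)

# descending=True sorts from largest to smallest along the chosen axis.
desc_sorted = jnp.sort(x, axis=-1, descending=True)

# axis=0 sorts each column independently.
col_sorted = jnp.sort(x, axis=0)

print(flat_sorted)  # [1 2 3 4 5 6]
print(desc_sorted)  # [[5 3 2]
                    #  [6 4 1]]
print(col_sorted)   # [[3 1 2]
                    #  [6 5 4]]
```

The stable flag does not change the values returned here, since equal elements are indistinguishable in the sorted output; it matters mainly when the resulting order of tied elements is observable, for example through jax.numpy.argsort().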
Returns:
Sorted array of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).
Return type:
Array
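The shape rule for the return value can be checked with a short sketch like the following. The input array is illustrative only.

```python
import jax.numpy as jnp

# Illustrative 3x4 input; the values are arbitrary.
x = jnp.array([[9, 8, 7, 6],
               [5, 4, 3, 2],
               [1, 0, 11, 10]])

# With an integer axis, the result keeps the input shape.
same_shape = jnp.sort(x, axis=0)

# With axis=None, the result is one-dimensional with a.size elements.
flattened = jnp.sort(x, axis=None)

print(same_shape.shape)  # (3, 4)
print(flattened.shape)   # (12,)
```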
Examples
Simple 1-dimensional sort:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
See also
- jax.numpy.argsort(): return indices of sorted values.
- jax.numpy.lexsort(): lexicographical sort of multiple arrays.
- jax.lax.sort(): lower-level function wrapping XLA’s Sort operator.
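To illustrate how jax.numpy.sort() relates to jax.numpy.argsort(), the sketch below indexes an array with the result of argsort and compares it with a direct sort. The variable names are illustrative only, and the default stable=True is assumed for argsort.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# argsort returns the indices that would sort x.
order = jnp.argsort(x)

# Indexing x with those indices gives the same values as sorting directly.
via_indices = x[order]
direct = jnp.sort(x)

print(order)        # [0 5 4 1 3 2]
print(via_indices)  # [1 1 2 3 4 5]
print(bool(jnp.array_equal(via_indices, direct)))  # True
```

With a stable sort, the two equal values (indices 0 and 5) keep their original relative order in the index array.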