jax.numpy.sort — JAX documentation

jax.numpy.sort

jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)

Return a sorted copy of an array.

JAX implementation of numpy.sort().
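Since this function is described as the JAX implementation of numpy.sort(), one way to check that correspondence is to run both on the same data. The following is a minimal sketch, assuming NumPy is installed alongside JAX; the sample values and the helper name sort_last_axis are illustrative only. It also shows that the call can be traced inside a jax.jit-compiled function.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Illustrative input data (not from the original page).
data = np.array([3.0, -1.0, 2.5, 0.0, 2.5], dtype=np.float32)

# Both functions sort along the last axis by default.
expected = np.sort(data)
result = jnp.sort(jnp.asarray(data))

# Hypothetical helper, for illustration only: jnp.sort is traceable,
# so it can be called inside a jit-compiled function.
@jax.jit
def sort_last_axis(x):
    return jnp.sort(x, axis=-1)

jitted = sort_last_axis(jnp.asarray(data))

# The JAX results agree with NumPy on this input.
assert np.array_equal(np.asarray(result), expected)
assert np.array_equal(np.asarray(jitted), expected)
print(result.tolist())  # [-1.0, 0.0, 2.5, 2.5, 3.0]
```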

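Parameters:

a (ArrayLike): array to sort.
axis (int | None): integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, a is flattened before being sorted.
kind (None): not supported; specify the sort algorithm with stable=True or stable=False instead.
order (None): not supported by JAX.
stable (bool): whether to use a stable sort. Defaults to True.
descending (bool): whether to sort in descending order. Defaults to False.
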
Returns:

Sorted array of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).

Return type:

Array
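The two shape rules above can be checked directly. This is a small sketch on made-up input values; it only inspects the shapes and contents of the results for an integer axis and for axis=None.

```python
import jax.numpy as jnp

# Illustrative 2-D input (values chosen arbitrarily).
x = jnp.array([[5, 2, 9],
               [1, 7, 3]])

# Integer axis: the result keeps the input shape a.shape.
by_row = jnp.sort(x, axis=1)
assert by_row.shape == x.shape            # (2, 3)

# axis=None: the input is flattened first, so the result has shape (a.size,).
flat = jnp.sort(x, axis=None)
assert flat.shape == (x.size,)            # (6,)

print(by_row.tolist())  # [[2, 5, 9], [1, 3, 7]]
print(flat.tolist())    # [1, 2, 3, 5, 7, 9]
```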

Examples

Simple 1-dimensional sort:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)
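The signature also exposes a keyword-only descending flag. A minimal sketch on the same input values, assuming a JAX version whose jnp.sort accepts descending (as the signature on this page does):

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# descending=True returns the values from largest to smallest.
desc = jnp.sort(x, descending=True)
print(desc.tolist())  # [5, 4, 3, 2, 1, 1]

# For this input the values match reversing the ascending sort.
assert bool((desc == jnp.sort(x)[::-1]).all())
```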

Sort along the last axis of an array:

>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
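To see how the choice of axis changes the result, the sketch below sorts one hypothetical 2-D array along axis=0 and along axis=1. With axis=0 each column is sorted independently; with axis=1 (the same as the default axis=-1 for a 2-D array) each row is sorted.

```python
import jax.numpy as jnp

# Hypothetical 2-D input chosen so that both axes need reordering.
x = jnp.array([[4, 1, 6],
               [2, 3, 5]])

# axis=0: sort down each column.
cols = jnp.sort(x, axis=0)
# axis=1: sort across each row (equal to the default axis=-1 here).
rows = jnp.sort(x, axis=1)

print(cols.tolist())  # [[2, 1, 5], [4, 3, 6]]
print(rows.tolist())  # [[1, 4, 6], [2, 3, 5]]
```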

See also