jax.Array.argpartition — JAX documentation

jax.Array.argpartition

abstract Array.argpartition(kth, axis=-1)

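Return the indices that partially sort the array along axis. Indexing with them places at position kth the element that a full sort would put there. Every element before it is less than or equal to it, and every element after it is greater than or equal to it.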

Refer to jax.numpy.argpartition() for the full documentation.
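As an illustration, here is a minimal sketch with made-up input values (not taken from the JAX docs). It calls the method on a 1-D array and then on a 2-D array along axis 0. It checks only the partition property and does not assume any particular order among the elements on either side of kth.

```python
import jax.numpy as jnp

# 1-D case: partition around index 2 (kth is passed as a plain Python int).
x = jnp.array([7, 1, 5, 3, 9, 2])
idx = x.argpartition(2)          # equivalent to jnp.argpartition(x, 2)
part = x[idx]

# The value at position 2 is the one a full sort would put there (here, 3) ...
assert int(part[2]) == int(jnp.sort(x)[2])
# ... everything before it is <= it, and everything after it is >= it.
assert bool(jnp.all(part[:2] <= part[2]))
assert bool(jnp.all(part[3:] >= part[2]))

# 2-D case: partition each column (axis=0) around index 0,
# which moves each column's minimum into row 0.
m = jnp.array([[4, 1, 3],
               [2, 8, 0]])
cols = m.argpartition(0, axis=0)
first_row = jnp.take_along_axis(m, cols, axis=0)[0]
print(first_row)  # [2 1 0], the column-wise minima
```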

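Parameters:

kth (int): static integer index about which to partition the array.
axis (int): static integer axis along which to partition the array; defaults to -1 (the last axis).
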
Return type:

Array