jax.Array.argpartition
abstract Array.argpartition(kth, axis=-1)
Return the indices that partially sort the array.
Refer to jax.numpy.argpartition() for the full documentation.
Parameters:
kth (int)
axis (int)
Return type:
Array
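
The following is a minimal sketch of how the method might be used, assuming a JAX release in which jax.numpy.argpartition() is available. The array values and the names x, kth, idx, partitioned, and pivot are illustrative only and do not come from the JAX documentation. The order of indices within each partition is not guaranteed, so the example checks only the partition property rather than printing the indices.

```python
import jax.numpy as jnp

# Illustrative input; the values are arbitrary.
x = jnp.array([7, 1, 5, 3, 9, 2])
kth = 2

# Indices that partially sort x around position kth along the last axis.
idx = x.argpartition(kth)

# Applying the indices yields a partially sorted copy of x.
partitioned = x[idx]

# The element at position kth is the one a full sort would place there
# (here the third-smallest value, 3).
pivot = partitioned[kth]

# Everything before position kth is <= pivot, everything after is >= pivot.
left_ok = bool(jnp.all(partitioned[:kth] <= pivot))
right_ok = bool(jnp.all(partitioned[kth + 1:] >= pivot))
```

A partial sort like this is mainly useful when only the kth-smallest element, or the set of the k smallest elements, is needed and a full argsort would do unnecessary work.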