jax.numpy.argmax (JAX documentation)

jax.numpy.argmax

jax.numpy.argmax(a, axis=None, out=None, keepdims=None)

Return the index of the maximum value of an array.

JAX implementation of numpy.argmax().
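
Because this function is the JAX implementation of numpy.argmax(), one quick way to sanity-check that behavior is to run both on the same inputs. The sketch below assumes NumPy is installed alongside JAX. The input lists are arbitrary illustrations, and one of them repeats its maximum to exercise NumPy's rule that the first occurrence wins on ties.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative inputs (assumed, not from the reference page):
# a 1-D array, a 1-D array with a tied maximum, and a 2-D array.
inputs = [
    [1, 3, 5, 4, 2],
    [1, 5, 2, 5],          # tie: the first 5 (index 1) should be returned
    [[1, 3, 2], [5, 4, 1]],
]

for data in inputs:
    np_x = np.array(data)
    jnp_x = jnp.array(data)
    # axis=None: both libraries flatten the input before searching.
    assert int(jnp.argmax(jnp_x)) == int(np.argmax(np_x))
    # Compare the result along every axis of the input.
    for axis in range(np_x.ndim):
        assert np.array_equal(np.asarray(jnp.argmax(jnp_x, axis=axis)),
                              np.argmax(np_x, axis=axis))

print("jnp.argmax matches np.argmax on all inputs")
```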

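Parameters:

a (ArrayLike): input array.
axis (int | None): optional integer specifying the axis along which to find the maximum value. If axis is not specified, a is flattened first.
out (None): unused by JAX; must be None.
keepdims (bool | None): if True, the reduced axis is kept in the result with size 1, so the output has the same number of dimensions as a.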

Returns:

an array containing the index of the maximum value along the specified axis. If axis is None, the result is a scalar index into the flattened array.

Return type:

Array
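
When axis is None the input is flattened, so for a multi-dimensional array the returned index refers to positions in the flattened array. The following is a minimal sketch, assuming a small illustrative 2-D input, that recovers (row, column) coordinates with jnp.unravel_index and shows how keepdims affects the shape of the result.

```python
import jax.numpy as jnp

# Illustrative 2-D input (assumed for this sketch).
x = jnp.array([[1, 3, 2],
               [5, 4, 1]])

# With axis=None (the default) x is flattened to [1, 3, 2, 5, 4, 1],
# so the result is a single index into that flattened array.
flat_idx = jnp.argmax(x)
# Convert the flat index back into 2-D coordinates.
row, col = jnp.unravel_index(flat_idx, x.shape)
print(int(flat_idx), int(row), int(col))  # 3 1 0
assert x[row, col] == jnp.max(x)

# Along an axis, the reduced dimension is dropped unless keepdims=True.
print(jnp.argmax(x, axis=0).shape)                 # (3,)
print(jnp.argmax(x, axis=0, keepdims=True).shape)  # (1, 3)
```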

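See also

jax.numpy.argmin(): return the index of the minimum value.
jax.numpy.nanargmax(): compute argmax while ignoring NaN values.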

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)

>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)

>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
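
As a further illustration, the sketch below uses jnp.argmax inside jax.jit and jax.vmap. The helper name row_argmax is hypothetical. Because axis determines the output shape, this sketch assumes it should be passed to jax.jit as a static argument rather than traced.

```python
from functools import partial

import jax
import jax.numpy as jnp

# row_argmax is a hypothetical helper for illustration.
# `axis` is marked static because it changes the output shape.
@partial(jax.jit, static_argnames="axis")
def row_argmax(x, axis=1):
    return jnp.argmax(x, axis=axis)

x = jnp.array([[1, 3, 2],
               [5, 4, 1]])
print(row_argmax(x))       # [1 0]

# Mapping the default (axis=None) argmax over the rows of x gives the
# same indices as argmax with axis=1.
per_row = jax.vmap(jnp.argmax)(x)
print(per_row)             # [1 0]
```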