jax.numpy.argmax
jax.numpy.argmax(a, axis=None, out=None, keepdims=None)
Return the index of the maximum value of an array.
JAX implementation of numpy.argmax().
Parameters:
- a (ArrayLike) – input array
- axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.
- out (None) – unused by JAX
- keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.
Returns:
an array containing the index of the maximum value along the specified axis.
Return type:
Array
See also
- jax.numpy.argmin(): return the index of the minimum value.
- jax.numpy.nanargmax(): compute argmax while ignoring NaN values.
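The two related functions listed above can be compared in a short sketch. The input values are made up for illustration and are not part of the original page.

```python
import jax.numpy as jnp

# Illustrative input containing one NaN entry (values are an assumption).
x = jnp.array([1.0, jnp.nan, 3.0])

# nanargmax skips the NaN entry and reports the position of the
# largest remaining value, here 3.0 at index 2.
idx_ignoring_nan = jnp.nanargmax(x)

# argmin mirrors argmax: it reports the position of the smallest value,
# here 1 at index 1.
idx_min = jnp.argmin(jnp.array([4, 1, 3]))

print(idx_ignoring_nan, idx_min)
```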
Note
When the maximum value occurs more than once along a particular axis, the smallest index is returned.
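A minimal sketch of this tie-breaking rule follows. The helper `argmax_last` is hypothetical (it is not a JAX function); it only shows one way the last occurrence could be recovered, by flipping the axis before calling `jnp.argmax`.

```python
import jax.numpy as jnp

# Each row contains its maximum twice.
x = jnp.array([[7, 2, 7],
               [0, 9, 9]])

# argmax reports the smallest index among tied maxima:
# row 0 ties at columns 0 and 2, row 1 ties at columns 1 and 2.
first_hits = jnp.argmax(x, axis=1)


def argmax_last(a, axis):
    """Hypothetical helper: index of the LAST maximum along `axis`."""
    a = jnp.asarray(a)
    # Reverse the axis, take the first maximum there, then map the
    # position back to an index in the original orientation.
    flipped = jnp.flip(a, axis=axis)
    return a.shape[axis] - 1 - jnp.argmax(flipped, axis=axis)


last_hits = argmax_last(x, axis=1)

print(first_hits, last_hits)
```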
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)

>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)

>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
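When `axis` is left unspecified the input is flattened, so the result indexes into the flattened array rather than the original shape. The sketch below shows one possible way to turn that flat index back into row and column coordinates with `jnp.unravel_index`; pairing the two functions is an illustration, not something the original page prescribes.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 2],
               [5, 4, 1]])

# With axis=None the array is flattened to [1, 3, 2, 5, 4, 1],
# so the maximum (5) sits at flat position 3.
flat_idx = jnp.argmax(x)

# Convert the flat position back into (row, column) coordinates
# for the original 2x3 shape.
row, col = jnp.unravel_index(flat_idx, x.shape)

print(flat_idx, row, col)
```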