Argmax on empty array returns 2147483647 · Issue #2899 · jax-ml/jax

How to reproduce:

```python
>>> jnp.argmax(jnp.array([]))
DeviceArray(2147483647, dtype=int32)
```

Confirmed to occur on CPU and TPU.

For reference, NumPy raises an exception instead: `ValueError: attempt to get argmax of an empty sequence`.
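The returned value 2147483647 is 2**31 - 1, the largest `int32` value, so it acts as an out-of-range sentinel rather than a valid index. Until `jnp.argmax` matches NumPy here, a caller could guard against it. The sketch below assumes a hypothetical helper named `argmax_checked` (not part of JAX) that raises the same `ValueError` as NumPy. It checks only the array's static shape, never its values, so the check does not depend on the data.

```python
import numpy as np
import jax.numpy as jnp


def argmax_checked(x, axis=None):
    """Hypothetical wrapper: raise like NumPy instead of returning a sentinel."""
    x = jnp.asarray(x)
    # Shapes are static in JAX, so this emptiness test is a plain Python bool.
    if axis is None:
        empty = x.size == 0
    else:
        # Only the reduced axis matters; e.g. shape (0, 3) with axis=1 is fine.
        empty = x.shape[axis] == 0
    if empty:
        raise ValueError("attempt to get argmax of an empty sequence")
    return jnp.argmax(x, axis=axis)


# The sentinel in the report is exactly the int32 maximum.
assert np.iinfo(np.int32).max == 2147483647 == 2**31 - 1
```

For non-empty input the wrapper simply defers to `jnp.argmax`; for example, `argmax_checked(jnp.array([1.0, 3.0, 2.0]))` gives index 1, while `argmax_checked(jnp.array([]))` raises `ValueError`.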