Argmax on empty array returns 2147483647 · Issue #2899 · jax-ml/jax
How to reproduce:
```python
>>> jnp.argmax(jnp.array([]))
DeviceArray(2147483647, dtype=int32)
```
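The returned number is not arbitrary: 2147483647 is exactly the largest `int32` value, 2**31 - 1. It cannot be a valid index into a length-0 array. The quick check below uses NumPy only for its `np.iinfo` integer-limits helper. Reading the value as a sentinel or initial value leaking out of the reduction is an assumption, not something stated in the report.

```python
import numpy as np

# The value JAX returned for the empty input in the report above.
reported = 2147483647

# It is exactly the largest int32 value, 2**31 - 1 ...
assert reported == np.iinfo(np.int32).max == 2**31 - 1

# ... so it can never be a valid index into an empty array.
empty = np.array([])
print(reported < empty.shape[0])  # False
```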
Confirmed to occur on CPU and TPU.
For reference, NumPy raises an exception instead: `ValueError: attempt to get argmax of an empty sequence`.
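Until the behavior matches NumPy, user code could guard the call itself. The sketch below is a hypothetical wrapper, `checked_argmax`, which is not part of the JAX API. It raises the same `ValueError` message as NumPy when the reduced dimension is empty, and otherwise defers to `jnp.argmax`. Because array shapes are static in JAX, the size check is an ordinary Python comparison and should also work at trace time under `jax.jit`.

```python
import jax.numpy as jnp


def checked_argmax(x, axis=None):
    """Hypothetical wrapper: refuse to reduce over an empty axis, like NumPy."""
    x = jnp.asarray(x)
    # Number of elements being reduced: all of them when axis is None,
    # otherwise the length of the chosen axis (shapes are static in JAX).
    reduced_size = x.size if axis is None else x.shape[axis]
    if reduced_size == 0:
        raise ValueError("attempt to get argmax of an empty sequence")
    return jnp.argmax(x, axis=axis)


try:
    checked_argmax(jnp.array([]))
except ValueError as err:
    print(err)  # attempt to get argmax of an empty sequence

print(int(checked_argmax(jnp.array([1.0, 3.0, 2.0]))))  # 1
```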