[jax2tf] Fix conversion for argmin/argmax; add conversion for reduce by copybara-service[bot] · Pull Request #7196 · jax-ml/jax
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently from JAX when the inputs contain NaN or Inf. Added
a few test cases in primitive_harness to expose the failures.
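To make the NaN issue concrete, here is a minimal pure-Python sketch (not the jax2tf code) contrasting a plain comparison-based scan with a NaN-aware reference. It assumes the NumPy convention, which JAX aims to follow: argmax/argmin return the index of the first NaN if one is present, and otherwise the first occurrence of the extreme value. It makes no claim about exactly what tf.argmax returns; scan_arg_extreme, reference_argmax and reference_argmin are hypothetical helper names.

```python
import math
import operator


def scan_arg_extreme(xs, better):
    """Linear scan keeping the first index whose value is 'better' than the current best.

    Every comparison involving NaN is False, so NaN entries are effectively
    skipped (unless one sits at index 0, where it then is never replaced).
    """
    best = 0
    for i in range(1, len(xs)):
        if better(xs[i], xs[best]):
            best = i
    return best


def nan_aware_arg_extreme(xs, better):
    """NumPy-style convention: the first NaN wins; otherwise the first extreme value."""
    for i, x in enumerate(xs):
        if math.isnan(x):
            return i
    return scan_arg_extreme(xs, better)


def reference_argmax(xs):
    return nan_aware_arg_extreme(xs, operator.gt)


def reference_argmin(xs):
    return nan_aware_arg_extreme(xs, operator.lt)


nan, inf = float("nan"), float("inf")
print(scan_arg_extreme([1.0, nan, 3.0], operator.gt))  # 2: the NaN is skipped
print(reference_argmax([1.0, nan, 3.0]))               # 1: the NaN wins
print(reference_argmax([inf, 1.0, inf]))               # 0: first of the tied maxima
print(reference_argmin([3.0, -inf, -inf]))             # 1: first of the tied minima
```

Inputs like these, with NaN and Inf in various positions, are the kind of case where a comparison-only scan and the reference can disagree.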
To convert argmin/argmax accurately, we need to use the XLA Reduce op.
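Why a reduction: argmin/argmax must track both the running extreme value and its index, which fits a variadic reduction over (value, index) pairs with a custom comparator. The sketch below models that idea with functools.reduce in plain Python; argmax_combiner and variadic_argmax are hypothetical names, not the jax2tf, lax, or XLA implementation. It assumes the backend may apply the reducer in any order (which is why XLA expects the reducer to be associative), so the comparator breaks every tie, including NaN vs. NaN, by index. That makes it a max under a strict total order, hence associative and commutative.

```python
import math
import random
from functools import reduce


def argmax_combiner(a, b):
    """Combine two (value, index) pairs, keeping the one argmax should report.

    NaN beats every number; among NaNs, or among equal numbers, the smaller
    index wins. With distinct indices this is a strict total order.
    """
    (av, ai), (bv, bi) = a, b
    a_nan, b_nan = math.isnan(av), math.isnan(bv)
    if a_nan and b_nan:
        return a if ai < bi else b
    if a_nan or b_nan:
        return a if a_nan else b
    if av > bv or (av == bv and ai < bi):
        return a
    return b


def variadic_argmax(values, order=None):
    """Fold (value, index) pairs with the combiner, optionally in a permuted order.

    The pair (-inf, 0) serves as the reduction's init value. For an empty
    input the fold returns it unchanged, i.e. index 0, which names no element.
    """
    pairs = [(v, i) for i, v in enumerate(values)]
    if order is not None:
        pairs = [pairs[k] for k in order]
    return reduce(argmax_combiner, pairs, (-math.inf, 0))[1]


nan = float("nan")
xs = [1.0, nan, 3.0, nan, 3.0]
print(variadic_argmax(xs))  # 1: the first NaN

# The same answer comes back whatever order the pairs are folded in.
for seed in range(5):
    order = list(range(len(xs)))
    random.Random(seed).shuffle(order)
    print(variadic_argmax(xs, order))
```

Carrying the index alongside the value is what lets a single reduction express the tie-breaking rules that a value-only comparison cannot.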
Also tightened the shape checks for lax.argmin and lax.argmax to ensure they are
not used with an empty reduced dimension. For example, with axis=-1 we
previously got an internal error:
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
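A tightened check of this kind can be sketched as a shape rule that validates the axis and the size of the reduced dimension up front, so the user gets a clear error instead of an internal XLA failure. This is a hypothetical pure-Python function, not JAX's actual shape rule or error text; it assumes the primitive receives an already-normalized (non-negative) axis, so a raw axis=-1 is rejected rather than passed through.

```python
def argminmax_output_shape(operand_shape, axis):
    """Hypothetical shape rule for argmin/argmax over a single axis.

    Rejects out-of-range axes (including negative ones, assumed to have been
    normalized by the caller) and empty reduced dimensions, then returns the
    operand shape with the reduced axis removed.
    """
    ndim = len(operand_shape)
    if not 0 <= axis < ndim:
        raise ValueError(
            f"argmin/argmax: invalid axis {axis} for operand shape {operand_shape}")
    if operand_shape[axis] < 1:
        raise ValueError(
            "argmin/argmax require a non-empty reduced dimension; "
            f"got operand shape {operand_shape} with axis={axis}")
    return tuple(d for i, d in enumerate(operand_shape) if i != axis)


# Reducing a size-3 axis is fine even if another axis is empty.
print(argminmax_output_shape((2, 0, 3), 2))  # (2, 0)

# Reducing the empty axis is rejected with a descriptive error.
try:
    argminmax_output_shape((2, 0, 3), 1)
except ValueError as e:
    print(e)
```

Note that only the reduced dimension must be non-empty: an empty output (such as shape (2, 0) above) is still well defined.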