jnp.ndarray: raise error for binary operations with lists and tuples by jakevdp · Pull Request #11234 · jax-ml/jax
This fixes a potentially confusing issue when JAX arrays are compared to python sequences.
When numpy encounters a list or tuple, it implicitly casts it to an array:
In [1]: import numpy as np
...: print(np.array([1, 1]) == [1, 2])
[ True False]
In JAX, we tend not to do this, both for performance reasons and because explicit is better than implicit here (see e.g. the discussion in #7737). Unfortunately, binary operators do not currently handle this case explicitly, which makes the current behavior somewhat surprising:
In [2]: import jax.numpy as jnp
...: print(jnp.array([1, 1]) == [1, 2])
False
This comes from Python's reflected-comparison fallback (the reflected `list.__eq__`, and then an identity check), and is a surprising result.
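The fallback can be reproduced without JAX. Below is a minimal pure-Python sketch, assuming the array's `__eq__` returns `NotImplemented` for a list operand; `ArrayLike` is a hypothetical stand-in, not a JAX class.

```python
class ArrayLike:
    """Hypothetical stand-in for an array type that only compares with itself."""

    def __eq__(self, other):
        if not isinstance(other, ArrayLike):
            # "I don't know how to compare with this type": Python will try
            # the other operand instead of raising an error.
            return NotImplemented
        return True


a = ArrayLike()

# Python first calls a.__eq__([1, 2]) -> NotImplemented, then the reflected
# [1, 2].__eq__(a) -> NotImplemented, and finally falls back to an identity
# check (a is [1, 2]), which yields False rather than an error.
result = a == [1, 2]
print(result)  # False
```

Because both sides decline, no error surfaces, and the user silently gets a single `False` instead of an elementwise result.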
With this PR, this kind of comparison raises an explicit `TypeError`:
In [1]: import numpy as np
...: print(np.array([1, 1]) == [1, 2])
[ True False]
In [2]: import jax.numpy as jnp
   ...: print(jnp.array([1, 1]) == [1, 2])
TypeError                                 Traceback (most recent call last)
      1 import jax.numpy as jnp
----> 2 print(jnp.array([1, 1]) == [1, 2])

~/github/google/jax/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   4583       other = other.__jax_array__()
   4584     if isinstance(other, _rejected_binop_types):
-> 4585       raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4586                       f"{type(self).__name__!r} and {type(other).__name__!r}")
   4587     if not isinstance(other, _accepted_binop_types):

TypeError: unsupported operand type(s) for ==: 'DeviceArray' and 'list'
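Once this raises, the user-side fix is to convert the sequence explicitly. A minimal sketch, assuming a recent JAX install; `jnp.asarray` performs the conversion that NumPy does implicitly:

```python
import jax.numpy as jnp

x = jnp.array([1, 1])
other = [1, 2]

# Convert the Python list to a JAX array explicitly; the comparison is then
# elementwise, matching NumPy's implicit behavior for np.array([1, 1]) == [1, 2].
result = x == jnp.asarray(other)
print(result)
```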
This is a draft, mainly to see whether the change breaks any tests. Still to do if we choose to make this change:
- add `set` and `dict`
- ensure arguments are in the correct order in the error message
- regression tests for this behavior
- CHANGELOG entry
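The first three to-do items could look roughly like the following pure-Python sketch. It is not JAX's implementation: `Arr`, `_defer_binop`, and `_REJECTED_BINOP_TYPES` are hypothetical names loosely modeled on `deferring_binary_op` and `_rejected_binop_types` from the traceback above. The rejected types include `set` and `dict`, reflected methods swap the operands back so the message reads in source order, and a few assertions stand in for regression tests.

```python
import operator

# Hypothetical: container types that binary ops refuse outright.
_REJECTED_BINOP_TYPES = (list, tuple, set, dict)


class Arr:
    """Hypothetical array stand-in wrapping a single Python number."""

    def __init__(self, value):
        self.value = value


def _defer_binop(opchar, op, reflected=False):
    """Build a binary-op method; `reflected` marks the __r*__ variants."""

    def method(self, other):
        if isinstance(other, _REJECTED_BINOP_TYPES):
            # For a reflected call such as `(1, 2) + arr`, `other` was written
            # on the left, so swap back to report operands in source order.
            left, right = (other, self) if reflected else (self, other)
            raise TypeError(
                f"unsupported operand type(s) for {opchar}: "
                f"{type(left).__name__!r} and {type(right).__name__!r}")
        if not isinstance(other, (Arr, int, float)):
            return NotImplemented  # let Python try the other operand
        rhs = other.value if isinstance(other, Arr) else other
        args = (rhs, self.value) if reflected else (self.value, rhs)
        return Arr(op(*args))

    return method


Arr.__eq__ = _defer_binop("==", operator.eq)
Arr.__add__ = _defer_binop("+", operator.add)
Arr.__radd__ = _defer_binop("+", operator.add, reflected=True)
Arr.__sub__ = _defer_binop("-", operator.sub)
Arr.__rsub__ = _defer_binop("-", operator.sub, reflected=True)


def error_message(thunk):
    """Return the TypeError message raised by thunk(), or None if none."""
    try:
        thunk()
    except TypeError as err:
        return str(err)
    return None


# Regression-style checks.
assert error_message(lambda: Arr(1) == [1, 2]) == (
    "unsupported operand type(s) for ==: 'Arr' and 'list'")
assert error_message(lambda: Arr(1) + {1: 2}) == (
    "unsupported operand type(s) for +: 'Arr' and 'dict'")
assert error_message(lambda: (1, 2) + Arr(1)) == (
    "unsupported operand type(s) for +: 'tuple' and 'Arr'")
assert (5 - Arr(2)).value == 3  # reflected op keeps operand order
```

In a real test suite these checks would more likely be `unittest` `assertRaisesRegex` cases; swapping operands inside the reflected method is one way to get the message order right, not necessarily the approach JAX takes.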
Closes #2406