jnp.ndarray: raise error for binary operations with lists and tuples by jakevdp · Pull Request #11234 · jax-ml/jax

This fixes a potentially confusing issue that arises when JAX arrays are compared to Python sequences.

When numpy encounters a list or tuple, it implicitly casts it to an array:

In [1]: import numpy as np
   ...: print(np.array([1, 1]) == [1, 2])
[ True False]
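As a quick check (a sketch that assumes only numpy is installed), the implicit result matches what you get by converting the list yourself with np.asarray, and a tuple operand behaves the same way:

```python
import numpy as np

x = np.array([1, 1])

# numpy converts the list (or tuple) operand to an array, then compares elementwise.
implicit_list = x == [1, 2]
implicit_tuple = x == (1, 2)

# The same comparison with the conversion written out explicitly.
explicit = x == np.asarray([1, 2])

print(implicit_list)   # [ True False]
print(implicit_tuple)  # [ True False]
print(explicit)        # [ True False]
```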

In JAX, we tend not to do this, both for performance reasons and because explicit is better than implicit here (see, e.g., the discussion in #7737). Unfortunately, binary operators do not currently handle this case explicitly, so the resulting behavior is surprising:

In [2]: import jax.numpy as jnp
   ...: print(jnp.array([1, 1]) == [1, 2])
False

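This is not an elementwise result at all. The array's __eq__ returns NotImplemented for a list, list.__eq__ likewise returns NotImplemented for the array, and Python then falls back to an identity comparison, which is False. This is a surprising result.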
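The mechanism can be reproduced in pure Python. In this minimal sketch, ToyArray is a hypothetical stand-in (not JAX code) whose __eq__ only recognizes its own type and scalars and returns NotImplemented for anything else:

```python
class ToyArray:
    """Hypothetical array-like type whose == only accepts ToyArrays and scalars."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        if isinstance(other, ToyArray):
            return [a == b for a, b in zip(self.values, other.values)]
        if isinstance(other, (int, float, complex)):
            return [a == other for a in self.values]
        # Unrecognized operand (e.g. a list): defer to the other object.
        return NotImplemented


x = ToyArray([1, 1])
print(x == ToyArray([1, 2]))  # [True, False]
print(x == [1, 2])            # False: both sides return NotImplemented,
                              # so Python falls back to `x is [1, 2]`
```

For != the fallback is the inverse identity check, so x != [1, 2] is True, which is equally misleading.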

With this PR, this kind of comparison instead raises an explicit TypeError:

In [1]: import numpy as np
   ...: print(np.array([1, 1]) == [1, 2])
[ True False]

In [2]: import jax.numpy as jnp
   ...: print(jnp.array([1, 1]) == [1, 2])

TypeError                                 Traceback (most recent call last)
      1 import jax.numpy as jnp
----> 2 print(jnp.array([1, 1]) == [1, 2])

~/github/google/jax/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   4583       other = other.__jax_array__()
   4584     if isinstance(other, _rejected_binop_types):
-> 4585       raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4586                       f"{type(self).__name__!r} and {type(other).__name__!r}")
   4587     if not isinstance(other, _accepted_binop_types):

TypeError: unsupported operand type(s) for ==: 'DeviceArray' and 'list'
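The check visible in the traceback can be sketched in pure Python. This is a simplified illustration, not JAX's implementation: ToyArray and the make_binop factory are hypothetical, and only the names _rejected_binop_types, _accepted_binop_types, deferring_binary_op and opchar come from the traceback; their contents here are assumptions. Lists and tuples raise TypeError, while other unrecognized types still return NotImplemented so Python can try the other operand:

```python
import operator

# Names borrowed from the traceback; the tuples' contents are assumptions.
_rejected_binop_types = (list, tuple)
_accepted_binop_types = (int, float, complex)


def make_binop(opchar, op):
    """Hypothetical factory that builds a binary dunder method for ToyArray."""

    def deferring_binary_op(self, other):
        if isinstance(other, _rejected_binop_types):
            # Fail loudly instead of letting Python fall back to identity.
            raise TypeError(f"unsupported operand type(s) for {opchar}: "
                            f"{type(self).__name__!r} and {type(other).__name__!r}")
        if isinstance(other, ToyArray):
            other_vals = other.values
        elif isinstance(other, _accepted_binop_types):
            other_vals = [other] * len(self.values)
        else:
            # Other unknown types still defer to their own (reflected) method.
            return NotImplemented
        return [op(a, b) for a, b in zip(self.values, other_vals)]

    return deferring_binary_op


class ToyArray:
    """Hypothetical array-like type used only for illustration."""

    def __init__(self, values):
        self.values = list(values)

    __eq__ = make_binop("==", operator.eq)
    __add__ = make_binop("+", operator.add)


x = ToyArray([1, 1])
print(x == ToyArray([1, 2]))  # [True, False]
print(x + 1)                  # [2, 2]
try:
    x == [1, 2]
except TypeError as err:
    print(err)  # unsupported operand type(s) for ==: 'ToyArray' and 'list'
```

Because == also tries the reflected method, [1, 2] == x reaches deferring_binary_op once list.__eq__ returns NotImplemented, so it raises the same error.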

This is a draft, mainly to see whether this change breaks any existing tests. Still to do if we choose to make this change:

Closes #2406