jax.numpy.array_equal: JAX documentation

jax.numpy.array_equal

jax.numpy.array_equal(a1, a2, equal_nan=False)

Check if two arrays have the same shape and are element-wise equal.

JAX implementation of numpy.array_equal().
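The semantics described above can be sketched in a few lines. The function array_equal_sketch below is a hypothetical helper written only for illustration; it is not JAX's actual source. It assumes the NumPy behavior that inputs of different shapes compare unequal (no broadcasting) and that equal_nan=True treats positions where both inputs are NaN as equal.

```python
import jax.numpy as jnp


def array_equal_sketch(a1, a2, equal_nan=False):
    """Hypothetical re-implementation illustrating array_equal semantics."""
    a1, a2 = jnp.asarray(a1), jnp.asarray(a2)
    # Arrays of different shapes are never equal; no broadcasting is applied.
    if a1.shape != a2.shape:
        return jnp.array(False)
    eq = a1 == a2
    if equal_nan:
        # Count positions where both inputs are NaN as equal.
        eq = eq | (jnp.isnan(a1) & jnp.isnan(a2))
    # Reduce the element-wise comparison to a single boolean scalar.
    return jnp.all(eq)


x = jnp.array([1.0, 2.0, jnp.nan])
print(bool(array_equal_sketch(x, x)))                  # False: NaN != NaN
print(bool(array_equal_sketch(x, x, equal_nan=True)))  # True
```

The shape check happens in Python on static shapes, which is why a mismatch can return early without comparing any elements.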

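Parameters:

a1 (ArrayLike): first input array to compare.
a2 (ArrayLike): second input array to compare.
equal_nan (bool): if True, NaN values in a1 are considered equal to NaN values at the same positions in a2. Default is False.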

Returns:

Boolean scalar array indicating whether the input arrays are element-wise equal.

Return type:

Array
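Because the result is a zero-dimensional boolean Array rather than a Python bool, how it is consumed matters. The sketch below assumes standard JAX tracing behavior: outside jax.jit the scalar converts with bool(), while inside a jitted function the result is a traced value, so the hypothetical helper select_if_equal uses jnp.where instead of a Python if statement.

```python
import jax
import jax.numpy as jnp


@jax.jit
def select_if_equal(a, b, x, y):
    """Hypothetical helper: return x if a and b are equal, else y."""
    same = jnp.array_equal(a, b)  # traced boolean scalar under jit
    # A Python `if same:` would fail here because `same` is a tracer;
    # jnp.where selects between the two values without leaving the trace.
    return jnp.where(same, x, y)


a = jnp.array([1, 2, 3])
result = jnp.array_equal(a, a)
print(result.shape, result.dtype)  # () bool
print(bool(result))                # True: outside jit the scalar converts to a Python bool
print(select_if_equal(a, a, 1.0, 0.0))
```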

Examples

>>> import jax.numpy as jnp
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]), equal_nan=True)
Array(True, dtype=bool)
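Two behaviors are easy to miss when switching from jnp.all(a1 == a2) to array_equal: array_equal does not broadcast its inputs, and it compares values after type promotion rather than comparing dtypes. A short check, assuming a recent JAX release with default settings:

```python
import jax.numpy as jnp

row = jnp.array([1, 2, 3])       # shape (3,)
matrix = jnp.array([[1, 2, 3]])  # shape (1, 3)

# `==` broadcasts (3,) against (1, 3), so every element compares equal.
print(bool(jnp.all(row == matrix)))        # True
# array_equal requires identical shapes, so the same data compares unequal.
print(bool(jnp.array_equal(row, matrix)))  # False

# Dtypes are not compared: int and float arrays with equal values are equal.
print(bool(jnp.array_equal(jnp.array([1, 2]), jnp.array([1.0, 2.0]))))  # True
```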