jax.numpy.intersect1d (JAX documentation)

jax.numpy.intersect1d

jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)

Compute the set intersection of two 1D arrays.

JAX implementation of numpy.intersect1d().
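For intuition, the classic approach used by numpy.intersect1d deduplicates both inputs, concatenates and sorts them, and keeps every value that equals its right-hand neighbour. The function below, intersect1d_sketch, is a hypothetical eager-mode illustration of that idea written with jax.numpy; it is not a claim about how jnp.intersect1d is implemented internally. Its boolean-mask step also makes visible why the output length depends on the input values.

```python
import jax.numpy as jnp


def intersect1d_sketch(ar1, ar2):
    """Hypothetical helper: sorted unique values present in both inputs.

    Eager mode only: the boolean mask below produces a data-dependent
    shape, so this function cannot be wrapped in jax.jit as written.
    """
    # Flatten and deduplicate so each value occurs at most once per input.
    a = jnp.unique(jnp.ravel(jnp.asarray(ar1)))
    b = jnp.unique(jnp.ravel(jnp.asarray(ar2)))
    # After sorting the concatenation, a shared value appears exactly twice, side by side.
    aux = jnp.sort(jnp.concatenate([a, b]))
    mask = aux[1:] == aux[:-1]
    # The number of True entries (and hence the output length) depends on the data.
    return aux[:-1][mask]


result = intersect1d_sketch([1, 2, 3, 4], [3, 4, 5, 6])
print(result.tolist())
```

The final indexing step, aux[:-1][mask], is exactly the kind of operation whose result shape cannot be known before the data is seen.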

Because the size of the output of intersect1d is data-dependent, the function is generally not compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.intersect1d to be used in such contexts.

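Parameters:

ar1: first array of values to intersect.
ar2: second array of values to intersect.
assume_unique: if True, treat both inputs as already free of duplicates, which can speed up the computation; results are unreliable if the inputs do contain duplicates. Default: False.
return_indices: if True, also return the indices of the intersected values within ar1 and ar2. Default: False.
size: optional static integer giving the length of the output; it must be specified for jnp.intersect1d to be used under jit() or other JAX transformations.
fill_value: when size is specified and the intersection has fewer than size elements, the value used to pad the remaining entries.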

Returns:

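An array intersection, or, if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). The returned values are:

intersection: a 1D array of the sorted unique values that appear in both ar1 and ar2.
ar1_indices: (returned only if return_indices=True) the indices of the values of intersection within ar1, so that intersection equals ar1[ar1_indices].
ar2_indices: (returned only if return_indices=True) the indices of the values of intersection within ar2, so that intersection equals ar2[ar2_indices].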

Return type:

Array | tuple[Array, Array, Array]

See also

Examples

>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

Computing intersection with indices:

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices gives the indices of the intersected values within ar1:

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices gives the indices of the intersected values within ar2:

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)
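The inputs in these examples contain no repeated values. Since jnp.intersect1d mirrors numpy.intersect1d, which returns the sorted unique values common to both inputs, repeated entries collapse to a single occurrence in the result. A short check, using two arrays chosen only for illustration, also shows the assume_unique flag applied to inputs that have already been deduplicated.

```python
import jax.numpy as jnp

ar1 = jnp.array([4, 1, 4, 2, 1])
ar2 = jnp.array([1, 4, 4, 9])

# Duplicates are dropped and the common values come back sorted.
common = jnp.intersect1d(ar1, ar2)
print(common.tolist())

# If the inputs are already known to be duplicate-free, assume_unique=True
# states that fact up front; passing it for inputs that DO contain
# duplicates would make the result unreliable.
u1 = jnp.unique(ar1)
u2 = jnp.unique(ar2)
common_from_unique = jnp.intersect1d(u1, u2, assume_unique=True)
print(common_from_unique.tolist())
```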