jax.numpy.intersect1d
jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)
Compute the set intersection of two 1D arrays.
JAX implementation of numpy.intersect1d().
Because the size of the output of intersect1d is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.intersect1d to be used in such contexts.
Parameters:
- ar1 (ArrayLike) – first array of values to intersect.
- ar2 (ArrayLike) – second array of values to intersect.
- assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. Default: False.
- return_indices (bool) – if True, return arrays of indices specifying where the intersected values first appear in the input arrays.
- size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value, and returned indices will be padded with an out-of-bound index.
- fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the intersection.
Returns:
An array intersection, or if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). Returned values are
- intersection: A 1D array containing each value that appears in both ar1 and ar2.
- ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar1 of values in intersection. For 1D inputs, intersection is equivalent to ar1[ar1_indices].
- ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar2 of values in intersection. For 1D inputs, intersection is equivalent to ar2[ar2_indices].
Return type:
Array | tuple[Array, Array, Array]
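Since the returned indices refer to the flattened inputs, a multi-dimensional input has to be indexed through its flattened view. The sketch below assumes a 2D ar1 and uses jnp.unravel_index to map the flat indices back to row and column positions; the variable names recovered, rows and cols are illustrative.

```python
import jax.numpy as jnp

ar1 = jnp.array([[1, 2], [3, 4]])  # 2D input, flattened to [1, 2, 3, 4]
ar2 = jnp.array([3, 4, 5, 6])

intersection, ar1_indices, ar2_indices = jnp.intersect1d(
    ar1, ar2, return_indices=True)

# ar1_indices point into the flattened array, so index ar1.ravel()
# rather than ar1 itself.
recovered = ar1.ravel()[ar1_indices]

# Convert the flat indices back to (row, column) positions in ar1.
rows, cols = jnp.unravel_index(ar1_indices, ar1.shape)
print(intersection, rows, cols)
```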
See also
- jax.numpy.union1d(): the set union of two 1D arrays.
- jax.numpy.setxor1d(): the set XOR of two 1D arrays.
- jax.numpy.setdiff1d(): the set difference of two 1D arrays.
Examples
>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)
Computing intersection with indices:
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)
ar1_indices gives the indices of the intersected values within ar1:
>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)
ar2_indices gives the indices of the intersected values within ar2:
>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)