jax.numpy.unique_inverse — JAX documentation
jax.numpy.unique_inverse
jax.numpy.unique_inverse(x, /, *, size=None, fill_value=None)
Return unique values from x, along with inverse indices.
JAX implementation of numpy.unique_inverse(); this is equivalent to calling jax.numpy.unique() with return_inverse and equal_nan set to True.
Because the size of the output of unique_inverse is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_inverse to be used in such contexts.
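A minimal sketch of the jit-compatible usage described above, assuming a recent JAX release. The wrapper name unique_inverse_fixed is hypothetical; size is marked as a static argument so that the output shape is known when the function is traced.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` is declared static: JAX needs it as a concrete Python int
# at trace time in order to fix the shape of `values`.
@partial(jax.jit, static_argnames="size")
def unique_inverse_fixed(x, size):
    # Hypothetical helper for illustration, not part of the JAX API.
    return jnp.unique_inverse(x, size=size)


x = jnp.array([3, 4, 1, 3, 1])
values, inverse_indices = unique_inverse_fixed(x, size=3)
```

Changing size triggers a recompilation, since each distinct static value produces a separately traced function.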
Parameters:
- x (ArrayLike) – N-dimensional array from which unique values will be extracted.
- size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
- fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
Returns:
A NamedTuple (values, inverse_indices), with the following properties:
- values: an array of shape (n_unique,) containing the unique values from x.
- inverse_indices: an array of shape x.shape. Contains the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.
See also
- jax.numpy.unique(): general function for computing unique values.
- jax.numpy.unique_values(): compute only values.
- jax.numpy.unique_counts(): compute only values and counts.
- jax.numpy.unique_all(): compute values, indices, inverse_indices, and counts.
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_inverse(x)
The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The inverse_indices attribute contains the indices of the input within values:
>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)
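The same reconstruction can be sketched for a multi-dimensional input. This assumes a JAX version in which inverse_indices has shape x.shape, as stated in the Returns section; the variable names x2d and reconstructed are illustrative only.

```python
import jax.numpy as jnp

# A 2D input: the unique values are still returned as a flat, sorted array.
x2d = jnp.array([[3, 4],
                 [1, 3]])
result2d = jnp.unique_inverse(x2d)

# inverse_indices keeps the shape of the input, so indexing `values`
# with it rebuilds an array with the same shape and contents as x2d.
reconstructed = result2d.values[result2d.inverse_indices]
```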
For examples of the size and fill_value arguments, see jax.numpy.unique().
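As a brief sketch of the padding behavior described under Parameters, assuming unique_inverse forwards size and fill_value to jax.numpy.unique() unchanged; the variable names padded_default and padded_custom are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])  # three unique values: 1, 3, 4

# Requesting more entries than there are unique values pads the result.
# With no fill_value, the padding uses the minimum unique value (1 here).
padded_default = jnp.unique_inverse(x, size=5)

# An explicit fill_value replaces the default padding.
padded_custom = jnp.unique_inverse(x, size=5, fill_value=-1)

# The inverse indices point only at the real unique values, so the
# reconstruction of x is unaffected by the padding.
roundtrip = padded_custom.values[padded_custom.inverse_indices]
```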