jax.numpy.unique
jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None, sorted=True)
Return the unique values from an array.
JAX implementation of numpy.unique().
Because the size of the output of unique is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique to be used in such contexts.
Parameters:
- ar (ArrayLike) – N-dimensional array from which unique values will be extracted.
- return_index (bool) – if True, also return the indices in ar where each value occurs.
- return_inverse (bool) – if True, also return the indices that can be used to reconstruct ar from the unique values.
- return_counts (bool) – if True, also return the number of occurrences of each unique value.
- axis (int | None) – if specified, compute unique values along the specified axis. If None (default), then flatten ar before computing the unique values.
- equal_nan (bool) – if True, consider NaN values equivalent when determining uniqueness.
- size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
- fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- sorted (bool) – unused by JAX.
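The effect of equal_nan is easiest to see on a small floating-point input. The following is a minimal sketch (the array x is an illustrative assumption, not taken from the reference text) comparing the default behavior with equal_nan=False:

```python
import jax.numpy as jnp

# Illustrative input containing two NaN entries.
x = jnp.array([1.0, jnp.nan, 2.0, jnp.nan])

# Default (equal_nan=True): the NaN entries are treated as one unique value.
merged = jnp.unique(x)

# equal_nan=False: NaN never compares equal to NaN, so each one is kept.
distinct = jnp.unique(x, equal_nan=False)

print(merged.shape, distinct.shape)
```

With the default setting the result should contain a single NaN; with equal_nan=False each NaN is kept as its own entry.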
Returns:
An array or tuple of arrays, depending on the values of return_index, return_inverse, and return_counts. Returned values are:
- unique_values: if axis is None, a 1D array of length n_unique. If axis is specified, shape is (*ar.shape[:axis], n_unique, *ar.shape[axis + 1:]).
- unique_index: (returned only if return_index is True) An array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in ar. For 1D inputs, ar[unique_index] is equivalent to unique_values.
- unique_inverse: (returned only if return_inverse is True) An array of shape (ar.size,) if axis is None, or of shape (ar.shape[axis],) if axis is specified. Contains the indices within unique_values of each value in ar. For 1D inputs, unique_values[unique_inverse] is equivalent to ar.
- unique_counts: (returned only if return_counts is True) An array of shape (n_unique,). Contains the number of occurrences of each unique value in ar.
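When several of the optional flags are enabled at once, the results come back as one tuple in the order listed above. A minimal sketch, using a small 1D input chosen for illustration:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# With all three flags set, the tuple is ordered as
# (unique_values, unique_index, unique_inverse, unique_counts).
values, index, inverse, counts = jnp.unique(
    x, return_index=True, return_inverse=True, return_counts=True
)

print(values)   # sorted unique values
print(index)    # first occurrence of each unique value in x
print(inverse)  # position of each element of x within values
print(counts)   # number of occurrences of each unique value
```

For this 1D input, x[index] recovers values and values[inverse] recovers x.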
See also
- jax.numpy.unique_counts(): shortcut to unique(arr, return_counts=True).
- jax.numpy.unique_inverse(): shortcut to unique(arr, return_inverse=True).
- jax.numpy.unique_all(): shortcut to unique with all return values.
- jax.numpy.unique_values(): like unique, but no optional return values.
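As a rough illustration of how the shortcuts relate to unique, the sketch below compares two of them against the equivalent unique calls. It assumes a JAX version recent enough to provide these functions, and it unpacks their results positionally rather than relying on field names:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# Reference results from the general-purpose function.
values, counts = jnp.unique(x, return_counts=True)

# Shortcut: values only, no optional outputs.
values_only = jnp.unique_values(x)

# Shortcut: values and counts, unpacked positionally.
values_c, counts_c = jnp.unique_counts(x)

print(values_only, values_c, counts_c)
```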
Examples
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique(x)
Array([1, 3, 4], dtype=int32)
JIT compilation & the size argument
If you try this under jit() or another transformation, you will get an error because the output shape is dynamic:
>>> jax.jit(jnp.unique)(x)
Traceback (most recent call last):
   ...
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.
The issue is that the output of transformed functions must have static shapes. In order to make this work, you can pass a static size parameter:
>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
>>> jit_unique(x, size=3)
Array([1, 3, 4], dtype=int32)
If the static size is smaller than the true number of unique values, the result is truncated to the first size sorted values:
>>> jit_unique(x, size=2)
Array([1, 3], dtype=int32)
If the static size is larger than the true number of unique values, the result is padded with fill_value, which defaults to the minimum unique value:
>>> jit_unique(x, size=5)
Array([1, 3, 4, 1, 1], dtype=int32)
>>> jit_unique(x, size=5, fill_value=0)
Array([1, 3, 4, 0, 0], dtype=int32)
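In practice, unique is often called inside a larger jitted function rather than being jitted directly. The following is a minimal sketch of that pattern; the helper count_unique, its max_unique parameter, and the sentinel value -1 are hypothetical names chosen for illustration, and the sentinel is assumed not to occur in the data:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["max_unique"])
def count_unique(x, max_unique):
    """Return padded unique values and how many of them are real entries.

    Hypothetical helper: assumes -1 never appears in x, so it can serve
    as a padding sentinel.
    """
    padded = jnp.unique(x, size=max_unique, fill_value=-1)
    n_unique = jnp.sum(padded != -1)
    return padded, n_unique


x = jnp.array([3, 4, 1, 3, 1])
padded, n_unique = count_unique(x, max_unique=5)
print(padded, n_unique)
```

Because max_unique is static, each distinct value triggers a new compilation. If max_unique is smaller than the true number of unique values, the output is truncated and the count is capped at max_unique.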
Multi-dimensional unique values
If you pass a multi-dimensional array to unique, it will be flattened by default:
>>> M = jnp.array([[1, 2],
...                [2, 3],
...                [1, 2]])
>>> jnp.unique(M)
Array([1, 2, 3], dtype=int32)
If you pass an axis keyword, you can find unique slices of the array along that axis:
>>> jnp.unique(M, axis=0)
Array([[1, 2],
       [2, 3]], dtype=int32)
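The same idea applies to other axes. As a small sketch, with an illustrative matrix A that is not part of the examples above, axis=1 deduplicates columns instead of rows:

```python
import jax.numpy as jnp

# Illustrative matrix whose first two columns are identical.
A = jnp.array([[1, 1, 2],
               [3, 3, 4]])

# Unique slices along axis=1, i.e. unique columns.
cols = jnp.unique(A, axis=1)
print(cols)
```

Here the duplicated column [1, 3] appears once in the result, so cols has shape (2, 2).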
Returning indices
If you set return_index=True, then unique returns the indices of the first occurrence of each unique value:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, indices = jnp.unique(x, return_index=True)
>>> print(values)
[1 3 4]
>>> print(indices)
[2 0 1]
>>> jnp.all(values == x[indices])
Array(True, dtype=bool)
In multiple dimensions, the unique values can be extracted with jax.numpy.take() evaluated along the specified axis:
>>> values, indices = jnp.unique(M, axis=0, return_index=True)
>>> jnp.all(values == jnp.take(M, indices, axis=0))
Array(True, dtype=bool)
Returning inverse
If you set return_inverse=True, then unique returns the indices within the unique values for every entry in the input array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, inverse = jnp.unique(x, return_inverse=True)
>>> print(values)
[1 3 4]
>>> print(inverse)
[1 2 0 1 0]
>>> jnp.all(values[inverse] == x)
Array(True, dtype=bool)
In multiple dimensions, the input can be reconstructed using jax.numpy.take():
>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)
Returning counts
If you set return_counts=True, then unique returns the number of occurrences within the input for every unique value:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, counts = jnp.unique(x, return_counts=True)
>>> print(values)
[1 3 4]
>>> print(counts)
[2 2 1]
For multi-dimensional arrays, this also returns a 1D array of counts, giving the number of occurrences of each unique slice along the specified axis:
>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
>>> print(values)
[[1 2]
 [2 3]]
>>> print(counts)
[2 1]