jax.numpy.unique — JAX documentation

jax.numpy.unique

jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None, sorted=True)

Return the unique values from an array.

JAX implementation of numpy.unique().

Because the size of the output of unique is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique to be used in such contexts.
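Besides passing size as a static argument, you can fix it as a plain Python int inside the function being compiled. A minimal sketch, assuming a hypothetical helper first_three_unique that only ever needs the three smallest unique values:

```python
import jax
import jax.numpy as jnp


@jax.jit
def first_three_unique(x):
    # Hypothetical helper. size=3 is a Python int fixed at trace time, so the
    # output always has the static shape (3,). Missing entries are padded with -1.
    return jnp.unique(x, size=3, fill_value=-1)


print(first_three_unique(jnp.array([5, 5, 2, 9, 2])))  # [2 5 9]
print(first_three_unique(jnp.array([7, 7, 7])))        # 7 followed by two -1 padding entries
```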

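Parameters:

ar: N-dimensional array from which unique values will be extracted.
return_index: if True, also return the indices in ar where each unique value first occurs.
return_inverse: if True, also return the indices that can be used to reconstruct ar from the unique values.
return_counts: if True, also return the number of occurrences of each unique value.
axis: if specified, compute unique slices along this axis. If None (default), ar is flattened before computing the unique values.
equal_nan: if True, consider NaN values equivalent when determining uniqueness.
size: if specified, return only the first size sorted unique elements. If there are fewer unique elements than size, the result is padded with fill_value.
fill_value: when size is specified and there are fewer unique elements than size, fill the remaining entries with fill_value. Defaults to the minimum unique value.
sorted: unused by JAX.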
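The effect of equal_nan can be seen on a float array containing several NaNs. A small sketch; only the number of entries and the number of NaNs are inspected here:

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 2.0, jnp.nan])

# Default equal_nan=True: all NaNs are treated as one value.
merged = jnp.unique(x)
# equal_nan=False: NaN != NaN, so each NaN is kept as a separate entry.
separate = jnp.unique(x, equal_nan=False)

print(merged.shape)    # (3,)
print(separate.shape)  # (4,)
```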

Returns:

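An array or tuple of arrays, depending on the values of return_index, return_inverse, and return_counts. Returned values are, in this order:

unique_values: the sorted unique values. If axis is None, a 1D array of length n_unique (or size, if specified); if axis is given, ar's shape with that axis replaced by n_unique.
unique_index: (only if return_index=True) the index of the first occurrence of each unique value in ar.
unique_inverse: (only if return_inverse=True) the indices into unique_values that reconstruct ar; for 1D inputs, unique_values[unique_inverse] equals ar.
unique_counts: (only if return_counts=True) the number of occurrences of each unique value in ar.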
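When several of the optional outputs are requested together, they come back as one tuple in the order listed above. A short sketch using the same input as the examples below:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
# Tuple order: values, first-occurrence indices, inverse indices, counts.
values, index, inverse, counts = jnp.unique(
    x, return_index=True, return_inverse=True, return_counts=True)

print(values)   # [1 3 4]
print(index)    # [2 0 1]
print(inverse)  # [1 2 0 1 0]
print(counts)   # [2 2 1]
```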

See also

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique(x)
Array([1, 3, 4], dtype=int32)

JIT compilation & the size argument

If you try this under jit() or another transformation, you will get an error because the output shape is dynamic:

>>> jax.jit(jnp.unique)(x)
Traceback (most recent call last):
   ...
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.

The issue is that the outputs of transformed functions must have static shapes. To make this work, pass size as a static argument:

>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
>>> jit_unique(x, size=3)
Array([1, 3, 4], dtype=int32)

If the static size is smaller than the true number of unique values, the result is truncated to the first size sorted unique values:

>>> jit_unique(x, size=2)
Array([1, 3], dtype=int32)
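Because truncation is silent, it can help to compute the true number of unique values, which is a scalar and therefore has a static shape under jit. A minimal sketch with a hypothetical helper num_unique, assuming a non-empty input without NaNs (NaN != NaN would count every NaN separately):

```python
import jax
import jax.numpy as jnp


@jax.jit
def num_unique(x):
    # Hypothetical helper. After sorting, equal values are adjacent, so every
    # position where the value changes starts a new distinct value.
    s = jnp.sort(jnp.ravel(x))
    return 1 + jnp.count_nonzero(s[1:] != s[:-1])


x = jnp.array([3, 4, 1, 3, 1])
print(int(num_unique(x)))  # 3, so jnp.unique(x, size=2) drops one value
```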

If the static size is larger than the true number of unique values, the result is padded with fill_value, which defaults to the minimum unique value:

>>> jit_unique(x, size=5)
Array([1, 3, 4, 1, 1], dtype=int32)
>>> jit_unique(x, size=5, fill_value=0)
Array([1, 3, 4, 0, 0], dtype=int32)
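Choosing fill_value well lets padded results feed straight into reductions. A sketch with a hypothetical helper sum_of_distinct: size=x.size is an upper bound on the number of unique values that is known statically from the input shape, and fill_value=0 leaves the sum unchanged:

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_distinct(x):
    # Hypothetical helper. x.size is a Python int at trace time, so it can be
    # used as a static size. Padding with 0 does not change the sum.
    vals = jnp.unique(x, size=x.size, fill_value=0)
    return vals.sum()


print(int(sum_of_distinct(jnp.array([3, 4, 1, 3, 1]))))  # 8 (= 1 + 3 + 4)
```

With the default fill_value (the minimum unique value), the padding entries would be counted again and the sum would be wrong, which is why an identity element for the reduction is used here.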

Multi-dimensional unique values

If you pass a multi-dimensional array to unique, it will be flattened by default:

>>> M = jnp.array([[1, 2],
...                [2, 3],
...                [1, 2]])
>>> jnp.unique(M)
Array([1, 2, 3], dtype=int32)

If you pass an axis argument, you can find unique slices of the array along that axis:

>>> jnp.unique(M, axis=0)
Array([[1, 2],
       [2, 3]], dtype=int32)
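The same works along other axes. A short sketch with a hypothetical matrix N, using axis=1 to deduplicate columns; the surviving columns come back in sorted (lexicographic) order, matching numpy.unique():

```python
import jax.numpy as jnp

# Columns of N are [1, 2], [0, 5] and [1, 2]; the last one is a duplicate.
N = jnp.array([[1, 0, 1],
               [2, 5, 2]])
cols = jnp.unique(N, axis=1)
print(cols)
# [[0 1]
#  [5 2]]
```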

Returning indices

If you set return_index=True, then unique returns the indices of the first occurrence of each unique value:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, indices = jnp.unique(x, return_index=True)
>>> print(values)
[1 3 4]
>>> print(indices)
[2 0 1]
>>> jnp.all(values == x[indices])
Array(True, dtype=bool)
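Since unique always returns sorted values, return_index is a convenient way to recover the unique values in order of first appearance instead. A small sketch:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
values, indices = jnp.unique(x, return_index=True)
# indices holds the first position of each unique value; sorting those
# positions and indexing x yields the unique values in first-seen order.
in_order = x[jnp.sort(indices)]
print(in_order)  # [3 4 1]
```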

In multiple dimensions, the unique values can be extracted with jax.numpy.take() evaluated along the specified axis:

>>> values, indices = jnp.unique(M, axis=0, return_index=True)
>>> jnp.all(values == jnp.take(M, indices, axis=0))
Array(True, dtype=bool)

Returning inverse

If you set return_inverse=True, then unique returns the indices within the unique values for every entry in the input array:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, inverse = jnp.unique(x, return_inverse=True)
>>> print(values)
[1 3 4]
>>> print(inverse)
[1 2 0 1 0]
>>> jnp.all(values[inverse] == x)
Array(True, dtype=bool)
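The inverse indices are dense integer ids in [0, number of unique values), which makes them usable as segment ids for group-by style reductions. A sketch with hypothetical labels and weights, summing the weights per label with jax.ops.segment_sum:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([10, 30, 10, 20])   # hypothetical group labels
weights = jnp.array([1, 2, 3, 4])      # hypothetical per-entry values

keys, inverse = jnp.unique(labels, return_inverse=True)
# inverse maps each label to its position in keys: [0, 2, 0, 1].
totals = jax.ops.segment_sum(weights, inverse, num_segments=keys.shape[0])

print(keys)    # [10 20 30]
print(totals)  # [4 4 2]
```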

In multiple dimensions, the input can be reconstructed using jax.numpy.take():

>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)

Returning counts

If you set return_counts=True, then unique returns the number of occurrences within the input for every unique value:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, counts = jnp.unique(x, return_counts=True)
>>> print(values)
[1 3 4]
>>> print(counts)
[2 2 1]
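One common use of the counts is finding the most frequent value (the mode). A small sketch; because jnp.argmax returns the first maximal position and the values are sorted, ties resolve to the smallest value:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
values, counts = jnp.unique(x, return_counts=True)
# 1 and 3 both occur twice; argmax picks the first, i.e. the smaller value.
mode = values[jnp.argmax(counts)]
print(mode)  # 1
```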

For multi-dimensional arrays with axis specified, this likewise returns a 1D array of counts, giving the number of occurrences of each unique slice along that axis:

>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
>>> print(values)
[[1 2]
 [2 3]]
>>> print(counts)
[2 1]