tf.argsort  |  TensorFlow v2.16.1 (original) (raw)

tf.argsort

Stay organized with collections Save and categorize content based on your preferences.

Returns the indices of a tensor that give its sorted order along an axis.

View aliases

Compat aliases for migration

SeeMigration guide for more details.

tf.compat.v1.argsort

tf.argsort(
    values, axis=-1, direction='ASCENDING', stable=False, name=None
)

Used in the notebooks

Used in the tutorials
MoViNet for streaming action recognition Image Classification with TensorFlow Hub Client-efficient large-model federated learning via `federated_select` and sparse aggregation

values = [1, 10, 26.9, 2.8, 166.32, 62.3] sort_order = tf.argsort(values) sort_order.numpy() array([0, 3, 1, 2, 5, 4], dtype=int32)

For a 1D tensor:

sorted = tf.gather(values, sort_order) assert tf.reduce_all(sorted == tf.sort(values))

For higher dimensions, the output has the same shape asvalues, but along the given axis, values represent the index of the sorted element in that slice of the tensor at the given position.

mat = [[30,20,10], [20,10,30], [10,30,20]] indices = tf.argsort(mat) indices.numpy() array([[2, 1, 0], [1, 0, 2], [0, 2, 1]], dtype=int32)

If axis=-1 these indices can be used to apply a sort using tf.gather:

tf.gather(mat, indices, batch_dims=-1).numpy() array([[10, 20, 30], [10, 20, 30], [10, 20, 30]], dtype=int32)

See also
tf.sort: Sort along an axis. tf.math.top_k: A partial sort that returns a fixed number of top values and corresponding indices.
Args
values 1-D or higher numeric Tensor.
axis The axis along which to sort. The default is -1, which sorts the last axis.
direction The direction in which to sort the values ('ASCENDING' or'DESCENDING').
stable If True, equal elements in the original tensor will not be re-ordered in the returned order. Unstable sort is not yet implemented, but will eventually be the default for performance reasons. If you require a stable order, pass stable=True for forwards compatibility.
name Optional name for the operation.
Returns
An int32 Tensor with the same shape as values. The indices that would sort each slice of the given values along the given axis.
Raises
ValueError If axis is not a constant scalar, or the direction is invalid.
tf.errors.InvalidArgumentError If the values.dtype is not a float orint type.