torch.topk — PyTorch 2.7 documentation
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
Returns the k largest elements of the given input tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
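As a small illustration of the dim default, the sketch below compares a call without dim against a call with dim=0 on a 2-D tensor. The tensor contents and variable names are made up for the example.

```python
import torch

x = torch.tensor([[1., 9., 3.],
                  [7., 2., 8.]])

# No dim given: the last dimension is used, so the top 2 of each row are returned.
row_top = torch.topk(x, k=2)
print(row_top.values)   # shape (2, 2)

# dim=0: the top 2 of each column are returned instead.
col_top = torch.topk(x, k=2, dim=0)
print(col_top.values)   # shape (2, 3)
```

The output keeps the shape of the input except along dim, where the size becomes k.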
If largest is False then the k smallest elements are returned.
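A minimal sketch of largest=False, using an arbitrary 1-D tensor chosen for the example:

```python
import torch

x = torch.tensor([4., 1., 5., 2., 3.])

# Ask for the 2 smallest elements instead of the 2 largest.
smallest = torch.topk(x, k=2, largest=False)
print(smallest.values)   # the two smallest values
print(smallest.indices)  # their positions in x
```

With the default sorted=True, the smallest values come back in ascending order.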
A namedtuple of (values, indices) is returned, holding the values and the indices of the largest k elements of each row of the input tensor in the given dimension dim. The indices refer to positions in the input tensor along dim.
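The returned namedtuple can be read by field name or unpacked like an ordinary tuple. The sketch below also checks, using torch.gather, that the indices point back at the returned values. The input tensor is an arbitrary example.

```python
import torch

x = torch.tensor([[1., 9., 3.],
                  [7., 2., 8.]])

result = torch.topk(x, k=2, dim=1)

# Access by field name ...
print(result.values, result.indices)

# ... or unpack like a plain tuple.
values, indices = result

# The indices select the returned values out of x along the same dimension.
recovered = torch.gather(x, dim=1, index=indices)
print(torch.equal(recovered, values))
```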
The boolean option sorted, if True, will make sure that the returned k elements are themselves sorted.
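To illustrate the sorted flag, the sketch below requests the same top 3 with and without sorting. With sorted=False the order of the returned elements should not be relied on, so the example only compares them after sorting in Python. The input values are arbitrary.

```python
import torch

x = torch.tensor([3., 1., 4., 1., 5., 9., 2., 6.])

ordered = torch.topk(x, k=3)                  # sorted=True by default
unordered = torch.topk(x, k=3, sorted=False)  # same elements, order not guaranteed

print(ordered.values)
# Same set of values once sorted in descending order.
print(sorted(unordered.values.tolist(), reverse=True))
# Indices still match their values, whatever the order.
print(torch.equal(x[unordered.indices], unordered.values))
```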
Note
When using torch.topk, the indices of tied elements are not guaranteed to be stable and may vary across different invocations.
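The note above can be seen with an input that contains ties. The values returned by torch.topk are well defined, but which of the tied positions show up in indices is not. One possible workaround, offered here as an assumption rather than a recommendation from this page, is to use torch.sort with stable=True and slice the first k entries, which keeps tied elements in their original order.

```python
import torch

x = torch.tensor([2., 5., 5., 1., 5.])  # three tied maxima at positions 1, 2, 4
k = 2

top = torch.topk(x, k)
print(top.values)   # both values are 5.
print(top.indices)  # some two of the positions 1, 2, 4; do not rely on which

# Assumed workaround: a stable descending sort preserves the order of ties,
# so slicing the first k entries gives reproducible indices.
sorted_values, sorted_indices = torch.sort(x, descending=True, stable=True)
stable_values, stable_indices = sorted_values[:k], sorted_indices[:k]
print(stable_indices)
```

A full sort does more work than a top-k selection, so this trade-off only makes sense when reproducible tie-breaking matters.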
Parameters
- input (Tensor) – the input tensor.
- k (int) – the k in “top-k”
- dim (int, optional) – the dimension to sort along
- largest (bool, optional) – controls whether to return largest or smallest elements
- sorted (bool, optional) – controls whether to return the elements in sorted order
Keyword Arguments
out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
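A minimal sketch of the out argument, assuming preallocated buffers whose shapes and dtypes match the result (a tensor of the input's dtype for values and a LongTensor for indices). The buffer names are made up for the example.

```python
import torch

x = torch.tensor([4., 1., 5., 2., 3.])
k = 3

# Preallocated output buffers: values share x's dtype, indices are int64.
values_buf = torch.empty(k, dtype=x.dtype)
indices_buf = torch.empty(k, dtype=torch.long)

# The results are written into the buffers passed via out.
torch.topk(x, k, out=(values_buf, indices_buf))
print(values_buf)
print(indices_buf)
```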
Example:
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))