torch.gather — PyTorch 2.7 documentation
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
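The indexing rule can be restated as an explicit loop over every position of index. The sketch below is illustrative only: gather_reference is a hypothetical helper, not part of PyTorch, and it is far slower than the built-in. It replaces the coordinate along dim with the value stored in index and leaves the other coordinates unchanged.

```python
import itertools
import torch

def gather_reference(input, dim, index):
    """Loop-based restatement of the rule (hypothetical helper, not a PyTorch API)."""
    out = torch.empty(index.shape, dtype=input.dtype)
    # Visit every position of index; out has exactly the same shape.
    for pos in itertools.product(*(range(n) for n in index.shape)):
        src = list(pos)
        src[dim] = int(index[pos])  # swap in the gathered coordinate along dim
        out[pos] = input[tuple(src)]
    return out

x = torch.arange(24).reshape(2, 3, 4)
for d in range(x.dim()):
    # Random but valid indices along dimension d.
    idx = torch.randint(0, x.size(d), tuple(x.shape))
    print(d, torch.equal(gather_reference(x, d, idx), torch.gather(x, d, idx)))
```

Each printed line should report agreement between the loop and torch.gather for that dimension.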
input and index must have the same number of dimensions. It is also required that index.size(d) <= input.size(d) for all dimensions d != dim. out will have the same shape as index. Note that input and index do not broadcast against each other.
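A small sketch of these shape rules, using made-up tensors: index may be smaller than input in every dimension, the result takes the shape of index, and because nothing is broadcast, reusing the same column indices for every row requires expanding index explicitly.

```python
import torch

x = torch.arange(12).reshape(3, 4)      # shape (3, 4)
idx = torch.tensor([[0, 3], [1, 1]])    # shape (2, 2), within x's size along dim 0

out = torch.gather(x, 1, idx)
print(out.shape)   # matches idx.shape, not x.shape
print(out)

# No broadcasting: a (1, 2) index would yield a (1, 2) result.
# To pick columns 0 and 3 from every row, expand the index to one row per input row.
cols = torch.tensor([[0, 3]]).expand(x.size(0), -1)   # shape (3, 2)
per_row = torch.gather(x, 1, cols)
print(per_row)
```

Here out only draws from the first two rows of x, since idx has two rows.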
Parameters
- input (Tensor) – the source tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
Keyword Arguments
- sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.
- out (Tensor, optional) – the destination tensor
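The keyword arguments can be exercised as follows. This is a minimal sketch with arbitrary example tensors (weights, rows, and buf are illustrative names). It assumes a CPU build, and it assumes out is best used with tensors that do not require grad, which is why the second call detaches weights.

```python
import torch

weights = torch.randn(5, 3, requires_grad=True)
rows = torch.tensor([[0, 0, 0], [4, 4, 4]])   # gather rows 0 and 4 along dim 0

# sparse_grad=True: the gradient w.r.t. weights comes back as a sparse tensor.
picked = torch.gather(weights, 0, rows, sparse_grad=True)
picked.sum().backward()
print(weights.grad.is_sparse)
dense = weights.grad.to_dense()   # densify only for inspection
print(dense)

# out=: write the result into a preallocated tensor with the shape of index.
buf = torch.empty(2, 3)
torch.gather(weights.detach(), 0, rows, out=buf)
print(torch.equal(buf, picked.detach()))
```

A sparse gradient may be useful when index touches only a small part of a large input, since only the gathered positions carry non-zero gradient.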
Example:
    >>> t = torch.tensor([[1, 2], [3, 4]])
    >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
    tensor([[ 1,  1],
            [ 4,  3]])