torch.index_select — PyTorch 2.7 documentation
torch.index_select(input, dim, index, *, out=None) → Tensor
Returns a new tensor that indexes the input tensor along dimension dim using the entries in index, which is a 1-D IntTensor or LongTensor.
The returned tensor has the same number of dimensions as the original tensor (input). The dim-th dimension has the same size as the length of index; all other dimensions have the same size as in the original tensor.
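The shape rule can be checked directly. The following is a minimal sketch; the tensor x and the index values are made up for illustration and are not part of the reference text.

```python
import torch

# Illustrative input: shape (2, 3, 4), values 0..23.
x = torch.arange(24).reshape(2, 3, 4)

# Five indices into dimension 2 (valid range 0..3); repeats are allowed.
index = torch.tensor([3, 0, 3, 1, 2])

selected = torch.index_select(x, 2, index)

# Same number of dimensions; only dimension 2 changes, to len(index).
print(x.shape)         # torch.Size([2, 3, 4])
print(selected.shape)  # torch.Size([2, 3, 5])
```

Position k along the selected dimension holds the slice of x at index[k], so selected[..., 0] equals x[..., 3] here.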
Note
The returned tensor does not share storage with the original tensor. If out has a different shape than expected, it is silently resized to the correct shape, reallocating the underlying storage if necessary.
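Both points in this note can be observed with a short sketch. The tensor names below are illustrative, and the use of an empty tensor as out is an assumption chosen to keep the example simple.

```python
import torch

x = torch.zeros(3, 4)
index = torch.tensor([0, 2])

# The result is a copy: writing to it in place leaves x untouched.
rows = torch.index_select(x, 0, index)
rows.fill_(1.0)
print(x.sum().item())     # 0.0
print(rows.sum().item())  # 8.0

# An out tensor with the wrong shape (here: empty) is resized to fit.
buf = torch.empty(0)
torch.index_select(x, 0, index, out=buf)
print(buf.shape)          # torch.Size([2, 4])
```

Because the result never aliases input, code that needs a view rather than a copy has to use a different mechanism, such as basic slicing.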
Parameters
- input (Tensor) – the input tensor.
- dim (int) – the dimension along which to index
- index (IntTensor or LongTensor) – the 1-D tensor containing the indices to select
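As a small sketch of the index parameter, the snippet below passes the same indices once as a 32-bit and once as a 64-bit integer tensor. The values are illustrative; indices may repeat and may appear in any order.

```python
import torch

x = torch.tensor([[10, 11, 12],
                  [20, 21, 22],
                  [30, 31, 32]])

# The same 1-D indices as an IntTensor (int32) and a LongTensor (int64).
idx32 = torch.tensor([2, 2, 0], dtype=torch.int32)
idx64 = idx32.to(torch.int64)

a = torch.index_select(x, 0, idx32)
b = torch.index_select(x, 0, idx64)

print(a)
# tensor([[30, 31, 32],
#         [30, 31, 32],
#         [10, 11, 12]])
print(torch.equal(a, b))  # True
```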
Keyword Arguments
- out (Tensor, optional) – the output tensor.
Example:
>>> import torch
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])
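For comparison, the two selections in the example give the same values as indexing x with the 1-D index tensor in the corresponding position. This is a minimal sketch, not part of the original example; by_rows and by_cols are illustrative names.

```python
import torch

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])

by_rows = torch.index_select(x, 0, indices)  # rows 0 and 2
by_cols = torch.index_select(x, 1, indices)  # columns 0 and 2

# Indexing with a 1-D index tensor selects the same elements.
print(torch.equal(by_rows, x[indices]))     # True
print(torch.equal(by_cols, x[:, indices]))  # True
```

Passing dim as an argument can be convenient when the dimension is only known at runtime, since the indexing form hard-codes the position of the index in the subscript.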