torch.Tensor.scatter_: PyTorch 2.7 documentation
Tensor.scatter_(dim, index, src, *, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
For a 3-D tensor, self is updated as:
    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
This is the reverse operation of the manner described in gather().
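The update rule and its relationship to gather() can be checked with a small sketch. The shapes, the seed, and the use of torch.randperm to build collision-free indices along dim 0 are illustrative assumptions, not part of the original documentation.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

dim = 0
src = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# Build an index with no collisions along dim 0: for each (j, k), the two
# rows of src are sent to two distinct rows of a 5-row destination.
index = torch.stack(
    [torch.randperm(5)[:2] for _ in range(3 * 4)], dim=1
).reshape(2, 3, 4)

result = torch.zeros(5, 3, 4).scatter_(dim, index, src)

# Reference implementation of the documented rule for dim == 0.
expected = torch.zeros(5, 3, 4)
for i in range(2):
    for j in range(3):
        for k in range(4):
            expected[index[i, j, k], j, k] = src[i, j, k]

# With unique indices, gather() reads back exactly what scatter_ wrote.
recovered = result.gather(dim, index)
```

The round trip through gather() only recovers src here because the indices were built to be unique along dim.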
self, index and src (if it is a Tensor) should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.
Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive.
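As a minimal sketch of these shape rules, the following uses an index that is smaller than src in both dimensions; the specific shapes and values are chosen only for illustration.

```python
import torch

src = torch.arange(1, 11).reshape(2, 5)   # shape (2, 5)
index = torch.tensor([[0, 1, 2]])         # shape (1, 3): smaller than src
out = torch.zeros(3, 5, dtype=src.dtype)  # self.size(0) == 3, so index values must lie in 0..2

out.scatter_(0, index, src)

# Only src[0, 0:3] is read, because index covers just that region;
# the remaining entries of src are ignored.
print(out)
```

Every value in index stays within 0 and out.size(0) - 1, and index is no larger than src or out along the dimensions the rules constrain.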
Warning
When indices are not unique, the behavior is non-deterministic (one of the values from src will be picked arbitrarily) and the gradient will be incorrect (it will be propagated to all locations in the source that correspond to the same index)!
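A small sketch of the non-unique case, with values chosen only for illustration: all three entries of src target the same destination, so the result is one of them, and code should not rely on which.

```python
import torch

index = torch.tensor([[0, 0, 0]])        # every entry targets position 0
src = torch.tensor([[10.0, 20.0, 30.0]])
out = torch.zeros(1, 2).scatter_(1, index, src)

# out[0, 0] holds one of 10., 20. or 30.; which one is not guaranteed.
# out[0, 1] is never written and stays 0.
winner = out[0, 0].item()
```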
Note
The backward pass is implemented only for src.shape == index.shape.
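The sketch below differentiates through a scatter in the supported case, where src and index have the same shape. It uses the out-of-place Tensor.scatter(), because an in-place update on a leaf tensor that requires grad is not allowed; the weights tensor is a hypothetical stand-in for an upstream gradient.

```python
import torch

base = torch.zeros(3, 5, requires_grad=True)
src = torch.arange(1.0, 6.0).reshape(1, 5).requires_grad_()
index = torch.tensor([[0, 1, 2, 0, 1]])     # same shape as src, unique per column

weights = torch.arange(15.0).reshape(3, 5)  # arbitrary upstream weights
out = base.scatter(0, index, src)           # out-of-place variant
(out * weights).sum().backward()

# Each src element receives the gradient of the slot it was written to.
grad_src_expected = weights.gather(0, index)
# Overwritten slots of base receive no gradient; the rest pass through.
grad_base_expected = weights.scatter(0, index, 0.0)
```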
Additionally accepts an optional reduce argument that specifies a reduction operation. Instead of overwriting, the reduction combines each value in src with the existing value in self at the indices specified in index. For each value in src, the reduction operation is applied to an index in self which is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
Given a 3-D tensor and reduction using the multiplication operation, self is updated as:
    self[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2
Reducing with the addition operation is the same as using scatter_add_().
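The equivalence with scatter_add_() can be illustrated with a short sketch. It uses the scalar form of scatter_ and builds a matching src for scatter_add_() with torch.full; the values are illustrative.

```python
import torch

index = torch.tensor([[2], [3]])

# Scalar value with reduce='add' ...
via_reduce = torch.full((2, 4), 2.0).scatter_(1, index, 1.23, reduce='add')

# ... matches scatter_add_ with a src tensor filled with the same value.
via_scatter_add = torch.full((2, 4), 2.0).scatter_add_(
    1, index, torch.full((2, 1), 1.23)
)
```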
Warning
The reduce argument with Tensor src is deprecated and will be removed in a future PyTorch release. Please use scatter_reduce_() instead for more reduction options.
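As a hedged sketch of the suggested migration, a multiplicative scatter with a Tensor src can be written with scatter_reduce_() and reduce='prod'. The shapes and values are illustrative, and include_self is left at its default of True so that the existing values in self take part in the product.

```python
import torch

index = torch.tensor([[2], [3]])
src = torch.full((2, 1), 1.23)

# Multiply the selected entries of self by the corresponding src values.
out = torch.full((2, 4), 2.0).scatter_reduce_(1, index, src, reduce='prod')

expected = torch.tensor([[2.0, 2.0, 2.46, 2.0],
                         [2.0, 2.0, 2.0, 2.46]])
```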
Parameters
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
- src (Tensor) – the source element(s) to scatter.
Keyword Arguments
- reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'.
Example:
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])
scatter_(dim, index, value, *, reduce=None) → Tensor:
Writes the value from value into self at the indices specified in the index tensor. This operation is equivalent to the previous version, with the src tensor filled entirely with value.
Parameters
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
- value (Scalar) – the value to scatter.
Keyword Arguments
- reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'.
Example:
>>> index = torch.tensor([[0, 1]])
>>> value = 2
>>> torch.zeros(3, 5).scatter_(0, index, value)
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
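One common use of the scalar form is building a one-hot matrix. The sketch below is illustrative rather than part of the original documentation: labels and num_classes are hypothetical inputs.

```python
import torch

labels = torch.tensor([2, 0, 1, 2])  # hypothetical class labels
num_classes = 3                      # assumed number of classes

# index must have the same number of dimensions as self, hence unsqueeze(1).
one_hot = torch.zeros(labels.size(0), num_classes).scatter_(
    1, labels.unsqueeze(1), 1.0
)
print(one_hot)
```

Each row of labels.unsqueeze(1) holds a single column index, so exactly one entry per row is set to 1.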