torch.Tensor.scatter_reduce_ - PyTorch 2.7 documentation
Tensor.scatter_reduce_(dim, index, src, reduce, *, include_self=True) → Tensor
Reduces all values from the src tensor to the indices specified in the index tensor in the self tensor, using the reduction given by the reduce argument ("sum", "prod", "mean", "amax", "amin"). Each value in src is reduced into the element of self whose position is given by the value's own index in src for every dimension != dim, and by the corresponding value in index for dimension = dim. If include_self=True, the values in the self tensor are included in the reduction.
self, index and src should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.
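As a rough illustration of these shape rules, the sketch below uses a 2-D case in which index is smaller than src along dim and smaller than self along dim. The tensor names and values (target, the 2x4 and 2x3 shapes) are made up for this example and are not taken from the original documentation.

```python
import torch

# Hypothetical tensors chosen only to illustrate the shape constraints.
target = torch.zeros(2, 4)                 # plays the role of self
src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])         # shape (2, 3)
index = torch.tensor([[0, 0],
                      [3, 1]])             # shape (2, 2), fits inside src

# Only src[i][j] with (i, j) inside index's shape is read; src[:, 2] is ignored.
target.scatter_reduce_(1, index, src, reduce="sum")
print(target)
# tensor([[3., 0., 0., 0.],
#         [0., 5., 0., 4.]])
```

Row 0 sends both 1 and 2 to column 0, while row 1 sends 4 to column 3 and 5 to column 1.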
For a 3-D tensor with reduce="sum" and include_self=True, the output is given as:
    self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
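The update rule above can be checked against a plain Python loop. The following is a minimal sketch, not the library's implementation: scatter_sum_reference is a hypothetical helper written for this illustration, and the shapes and random values are arbitrary assumptions.

```python
import itertools
import torch

def scatter_sum_reference(target, dim, index, src):
    """Hypothetical loop-based reference for reduce="sum", include_self=True."""
    out = target.clone()
    # Visit every position of index; the destination differs from the
    # source position only along `dim`, where index supplies the coordinate.
    for pos in itertools.product(*(range(n) for n in index.shape)):
        dest = list(pos)
        dest[dim] = int(index[pos])
        out[tuple(dest)] += src[pos]
    return out

torch.manual_seed(0)
target = torch.randint(0, 10, (3, 4, 5)).float()
src = torch.randint(0, 10, (2, 3, 4)).float()
index = torch.randint(0, 3, (2, 3, 4))   # values < 3 are valid along every dim

for dim in range(3):
    expected = scatter_sum_reference(target, dim, index, src)
    actual = target.clone().scatter_reduce_(dim, index, src, reduce="sum")
    assert torch.equal(expected, actual)
```

Integer-valued floats are used so that the sums are exact regardless of accumulation order.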
Note
This operation may behave nondeterministically when given tensors on a CUDA device. See Reproducibility for more information.
Note
The backward pass is implemented only for src.shape == index.shape.
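A small sketch of the supported backward case, where src and index have the same shape. It uses the out-of-place Tensor.scatter_reduce so that no tensor tracked by autograd is modified in place. The weights tensor is an assumption added here only to make the gradient easy to read.

```python
import torch

src = torch.tensor([1., 2., 3., 4., 5., 6.], requires_grad=True)
index = torch.tensor([0, 1, 0, 1, 2, 1])   # same shape as src
base = torch.tensor([1., 2., 3., 4.])

out = base.scatter_reduce(0, index, src, reduce="sum")

# Hypothetical per-slot weights so each output slot has a distinct gradient.
weights = torch.tensor([10., 20., 30., 40.])
(out * weights).sum().backward()

# For "sum", each src element receives the gradient of the slot it was added to.
print(src.grad)  # tensor([10., 20., 10., 20., 30., 20.])
```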
Warning
This function is in beta and may change in the near future.
Parameters
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter and reduce
- src (Tensor) – the source elements to scatter and reduce
- reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")
- include_self (bool) – whether elements from the self tensor are included in the reduction
Example:
    >>> import torch
    >>> src = torch.tensor([1., 2., 3., 4., 5., 6.])
    >>> index = torch.tensor([0, 1, 0, 1, 2, 1])
    >>> input = torch.tensor([1., 2., 3., 4.])
    >>> input.scatter_reduce(0, index, src, reduce="sum")
    tensor([5., 14., 8., 4.])
    >>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False)
    tensor([4., 12., 5., 4.])
    >>> input2 = torch.tensor([5., 4., 3., 2.])
    >>> input2.scatter_reduce(0, index, src, reduce="amax")
    tensor([5., 6., 5., 2.])
    >>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False)
    tensor([3., 6., 5., 2.])
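The example above covers "sum" and "amax". As a supplementary sketch (not part of the original documentation), the same inputs can be run through the in-place method with "prod" and "mean". The variable name base and the use of clone() to keep the starting tensor intact are choices made for this illustration.

```python
import torch

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 1, 0, 1, 2, 1])
base = torch.tensor([1., 2., 3., 4.])

# scatter_reduce_ modifies its receiver, so each call works on a fresh clone.
prod_with_self = base.clone().scatter_reduce_(0, index, src, reduce="prod")
# values: [1*1*3, 2*2*4*6, 3*5, 4] = [3, 96, 15, 4]

prod_no_self = base.clone().scatter_reduce_(
    0, index, src, reduce="prod", include_self=False)
# values: [1*3, 2*4*6, 5, 4] = [3, 48, 5, 4]

mean_with_self = base.clone().scatter_reduce_(0, index, src, reduce="mean")
# values: [(1+1+3)/3, (2+2+4+6)/4, (3+5)/2, 4] = [5/3, 3.5, 4, 4]

mean_no_self = base.clone().scatter_reduce_(
    0, index, src, reduce="mean", include_self=False)
# values: [(1+3)/2, (2+4+6)/3, 5, 4] = [2, 4, 5, 4]
```

Slot 3 is never referenced by index, so it keeps its original value of 4 in every case, including those with include_self=False.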