torch.Tensor.scatter_add_
Adds all values from the tensor src into self at the indices specified in the index tensor, in a similar fashion as scatter_(). For each value in src, it is added to an index in self which is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
    self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
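The update rule above can be spelled out with explicit loops. The following is a minimal sketch for the 2-D case only, assuming dim is 0 or 1. The helper name scatter_add_reference is hypothetical and not part of PyTorch; it is meant to illustrate the indexing, not to mirror the library's actual implementation.

```python
import torch

def scatter_add_reference(self_tensor, dim, index, src):
    """Loop-based illustration of scatter_add_ for 2-D tensors (hypothetical helper)."""
    out = self_tensor.clone()
    # Only positions covered by `index` are visited.
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            target = index[i, j].item()
            if dim == 0:
                out[target, j] += src[i, j]   # self[index[i][j]][j] += src[i][j]
            else:  # dim == 1
                out[i, target] += src[i, j]   # self[i][index[i][j]] += src[i][j]
    return out

src = torch.arange(1.0, 11.0).reshape(2, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])

expected = torch.zeros(3, 5).scatter_add_(0, index, src)
result = scatter_add_reference(torch.zeros(3, 5), 0, index, src)
print(torch.equal(result, expected))
```

The comparison against the built-in scatter_add_ is there as a sanity check on the sketch; the loop version is far slower and is only useful for reading off which element lands where.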
self, index and src should have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.
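To make the shape rules concrete, here is a short sketch with made-up tensors. It shows an index that is smaller than src (only the covered elements of src are used) and an index row that has to be expanded by hand because no broadcasting takes place. The variable names are illustrative assumptions.

```python
import torch

# index is smaller than src in dimension 1: only src[:, :2] is read.
src = torch.arange(1.0, 11.0).reshape(2, 5)    # shape (2, 5)
index = torch.tensor([[0, 1], [2, 0]])         # shape (2, 2)
out_small = torch.zeros(3, 5).scatter_add_(0, index, src)

# A single row of indices is not broadcast across the rows of src:
# with shape (1, 5), only the first row of src is scattered.
ones = torch.ones(2, 5)
row_index = torch.tensor([[0, 1, 2, 0, 0]])    # shape (1, 5)
out_one_row = torch.zeros(3, 5).scatter_add_(0, row_index, ones)

# To reuse the same indices for every row, expand index explicitly.
out_all_rows = torch.zeros(3, 5).scatter_add_(0, row_index.expand(2, 5), ones)

print(out_small)
print(out_one_row.sum().item(), out_all_rows.sum().item())
```

In the second case the sum of the result is 5 rather than 10, since the second row of ones is never touched unless index is expanded to cover it.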
>>> import torch
>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[1., 0., 0., 1., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])
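Because values that share a target index are accumulated, one possible use is summing values per group. The following is a small sketch with made-up data, where values and groups are illustrative names and groups holds a group id for each element.

```python
import torch

values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
groups = torch.tensor([0, 2, 0, 1, 2])   # group id of each value

# One output slot per group; values with the same id are added together.
sums = torch.zeros(3).scatter_add_(0, groups, values)
print(sums)  # group 0: 1 + 3, group 1: 4, group 2: 2 + 5
```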