Segment CSR (pytorch_scatter 2.1.1 documentation)
torch_scatter.segment_csr(src: Tensor, indptr: Tensor, out: Tensor | None = None, reduce: str = 'sum') → Tensor
Reduces all values from the src tensor into out within the ranges specified by the indptr tensor along the last dimension of indptr. For each value in src, its output index equals its index in src for every dimension other than indptr.dim() - 1; along dimension indptr.dim() - 1, its output index is the index of the range in indptr that contains it. The reduction to apply is selected via the reduce argument.
Formally, if src and indptr are \(n\)-dimensional and \(m\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-2}, y)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-2}, y - 1, x_m, ..., x_{n-1})\). Moreover, the values of indptr must lie between \(0\) and \(x_{m-1}\), the size of the segmented dimension of src, and be sorted in non-decreasing order. The indptr tensor supports broadcasting when its dimensions do not match those of src.
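The size rule and the constraints on indptr can be checked mechanically. The helpers csr_output_shape and check_indptr_values below are hypothetical, written only for illustration. They assume that broadcasting means a size-1 leading dimension of indptr stretches to match src, as in the example further down where indptr has shape (1, 4) and src has shape (10, 6, 64).

```python
def csr_output_shape(src_shape, indptr_shape):
    """Return the expected shape of `out` for the given src and indptr shapes.

    Hypothetical helper: a size-1 leading dimension of indptr is treated as
    broadcastable against the matching dimension of src.
    """
    n, m = len(src_shape), len(indptr_shape)
    if not 1 <= m <= n:
        raise ValueError("indptr must have between 1 and src.dim() dimensions")
    for d in range(m - 1):
        if indptr_shape[d] not in (1, src_shape[d]):
            raise ValueError(f"indptr dim {d} cannot broadcast to src")
    y = indptr_shape[-1]
    if y < 1:
        raise ValueError("indptr needs at least one pointer")
    # Dims 0..m-2 come from src, dim m-1 becomes y - 1, the rest are copied.
    return tuple(src_shape[:m - 1]) + (y - 1,) + tuple(src_shape[m:])


def check_indptr_values(indptr_row, segment_dim_size):
    """Check one row of pointers: values in [0, size] and never decreasing.

    segment_dim_size is src.size(indptr.dim() - 1), the segmented dimension.
    """
    if any(p < 0 or p > segment_dim_size for p in indptr_row):
        return False
    return all(a <= b for a, b in zip(indptr_row, indptr_row[1:]))


print(csr_output_shape((10, 6, 64), (1, 4)))  # (10, 3, 64)
print(check_indptr_values([0, 2, 5, 6], 6))   # True
```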
For one-dimensional tensors with reduce="sum", the operation computes
\[\mathrm{out}_i = \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.\]
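Read literally, the formula sums the half-open slice from indptr[i] up to (but excluding) indptr[i+1] for each output index i. The sketch below transcribes it for Python lists; segment_csr_sum_1d is an illustrative name, not part of torch_scatter.

```python
def segment_csr_sum_1d(src, indptr):
    """out[i] = sum of src[j] for j = indptr[i], ..., indptr[i + 1] - 1."""
    # The slice src[a:b] covers exactly the indices a, ..., b - 1.
    return [sum(src[indptr[i]:indptr[i + 1]]) for i in range(len(indptr) - 1)]


src = [1, 2, 3, 4, 5, 6]
indptr = [0, 2, 5, 6]
print(segment_csr_sum_1d(src, indptr))  # [3, 12, 6]
```

An empty range, where indptr[i] == indptr[i+1], gives the empty sum 0.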
Because it operates directly on index pointers, segment_csr() is the fastest method for applying grouped reductions.
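Grouped data is often described by one group id per element rather than by pointers. When the ids are sorted, the pointers follow from counting the members of each group and taking a running sum. The helper indptr_from_sorted_ids below is a hypothetical sketch in plain Python; with tensors, torch.bincount and torch.cumsum would be natural building blocks for the same idea.

```python
from itertools import accumulate


def indptr_from_sorted_ids(group_ids, num_groups):
    """Build CSR index pointers from group ids sorted in non-decreasing order.

    Hypothetical helper: every id must lie in range(num_groups).
    """
    if any(a > b for a, b in zip(group_ids, group_ids[1:])):
        raise ValueError("group_ids must be sorted")
    counts = [0] * num_groups
    for g in group_ids:
        counts[g] += 1
    # indptr[i] is the number of elements belonging to groups < i.
    return [0] + list(accumulate(counts))


print(indptr_from_sorted_ids([0, 0, 1, 1, 1, 3], 4))  # [0, 2, 5, 5, 6]
```

Group 2 has no members, so it appears as the repeated pointer value 5, that is, an empty segment.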
Parameters:
- src – The source tensor.
- indptr – The index pointers between elements to segment. The number of dimensions of indptr needs to be less than or equal to that of src.
- out – The destination tensor. (default: None)
- reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
Return type: Tensor
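The plain-Python sketch below shows what each reduce option computes on a one-dimensional input. The function segment_csr_1d is hypothetical, and returning 0 for an empty segment under every mode is an assumption of this sketch; check the library's behavior for empty segments before relying on it.

```python
def segment_csr_1d(src, indptr, reduce="sum"):
    """Reference semantics of the four reduce modes on 1-D Python lists.

    Assumption: empty segments yield 0 for every mode (not taken from the docs).
    """
    reducers = {
        "sum": sum,
        "mean": lambda seg: sum(seg) / len(seg),
        "min": min,
        "max": max,
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce: {reduce!r}")
    out = []
    for start, end in zip(indptr[:-1], indptr[1:]):
        segment = src[start:end]
        out.append(reducers[reduce](segment) if segment else 0)
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
indptr = [0, 2, 5, 6]
for mode in ("sum", "mean", "min", "max"):
    print(mode, segment_csr_1d(src, indptr, mode))
# sum [3.0, 12.0, 6.0]
# mean [1.5, 4.0, 6.0]
# min [1.0, 3.0, 6.0]
# max [2.0, 5.0, 6.0]
```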
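import torch
from torch_scatter import segment_csr

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1)  # Shape (1, 4): broadcasts over src's first dim; the last dim (64) is kept.
out = segment_csr(src, indptr, reduce="sum")  # Segments dim 1 of src (size 6) into 3 ranges.
print(out.size())  # torch.Size([10, 3, 64])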