torch.split — PyTorch 2.7 documentation
torch.split(tensor, split_size_or_sections, dim=0)
Splits the tensor into chunks. Each chunk is a view of the original tensor.
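Because each chunk is a view, an in-place write to a chunk should show up in the original tensor. The sketch below illustrates this; the variable names (base, first, second) are illustrative and not part of the API.

```python
import torch

# `base`, `first`, and `second` are illustrative names.
base = torch.zeros(4, 2)
first, second = torch.split(base, 2)  # two chunks of 2 rows each

# In-place writes to the chunks are visible through the original tensor.
first[0, 0] = 7.0    # row 0 of `base`
second[1, 1] = 9.0   # row 3 of `base`

wrote_through = base[0, 0].item() == 7.0 and base[3, 1].item() == 9.0

# The first chunk starts at the same memory address as the original.
same_start = first.data_ptr() == base.data_ptr()
```

If an independent copy is needed, calling clone() on a chunk is one option.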
If split_size_or_sections is an integer, then tensor will be split into equally sized chunks (if possible). The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
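As a small illustration of the integer case, the sketch below splits a tensor of size 10 with a split size of 4, so the last chunk comes out smaller. The names x, split_size, and sizes are chosen for this example only.

```python
import math
import torch

x = torch.arange(10)   # size 10 along dim 0
split_size = 4         # 10 is not divisible by 4

chunks = torch.split(x, split_size)
sizes = [chunk.shape[0] for chunk in chunks]

# The number of chunks is the ceiling of size / split_size.
expected_count = math.ceil(x.shape[0] / split_size)

print(sizes)                          # [4, 4, 2]
print(len(chunks) == expected_count)  # True
```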
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
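The list case can be sketched as follows. The example also assumes, based on behavior in recent PyTorch releases rather than on the text above, that the listed sizes must sum to the tensor's size along dim and that a mismatch raises a RuntimeError. The names head, body, and tail are illustrative.

```python
import torch

x = torch.arange(10).reshape(5, 2)

# Three chunks with 1, 3, and 1 rows; the sizes sum to x.shape[0] == 5.
head, body, tail = torch.split(x, [1, 3, 1])
row_counts = [t.shape[0] for t in (head, body, tail)]

# Concatenating the chunks along the same dim reproduces the input.
roundtrip_ok = torch.equal(torch.cat((head, body, tail), dim=0), x)

# Assumption: sizes that do not sum to the dimension size are rejected.
try:
    torch.split(x, [1, 3])
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
```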
Parameters
- tensor (Tensor) – tensor to split.
- split_size_or_sections (int or list(int)) – size of a single chunk or list of sizes for each chunk
- dim (int) – dimension along which to split the tensor.
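To illustrate the dim parameter, the sketch below splits a 3-D tensor along its last dimension and along the default dim=0. It assumes that a negative dim counts from the end, as is usual for PyTorch dimension arguments.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Split the last dimension (size 4) into chunks of width 2.
by_index = torch.split(x, 2, dim=2)
by_negative = torch.split(x, 2, dim=-1)  # assumption: -1 refers to the last dim

shapes = [tuple(t.shape) for t in by_index]
same = all(torch.equal(a, b) for a, b in zip(by_index, by_negative))

# Default dim=0: size 2 along dim 0 with split size 1 gives two chunks.
default_shapes = [tuple(t.shape) for t in torch.split(x, 1)]
```

Only the size along dim changes between chunks; all other dimensions keep the shape of the input.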
Return type
tuple[torch.Tensor, …]
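Since the result is a tuple of tensors, it can be unpacked or iterated directly. A brief sketch, with illustrative names:

```python
import torch

result = torch.split(torch.arange(6), 2)

# The return value is a plain tuple whose items are tensors.
is_tuple = isinstance(result, tuple)
all_tensors = all(isinstance(t, torch.Tensor) for t in result)

# Unpacking works when the number of chunks is known in advance.
a, b, c = result
lengths = [len(t) for t in result]
```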
Example:
>>> import torch
>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))