torch.split — PyTorch 2.7 documentation

torch.split(tensor, split_size_or_sections, dim=0)

Splits the tensor into chunks. Each chunk is a view of the original tensor.
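As a minimal sketch of what "view" means here (the tensor `t` and the variable names are illustrative, not part of the original page), an in-place write to one chunk is visible through the original tensor, because the chunk shares its storage:

```python
import torch

# Illustrative tensor; no autograd is involved in this sketch.
t = torch.zeros(4, 2)

# Split into two chunks of 2 rows each along dim 0.
first, second = torch.split(t, 2)

# Write in place through the first chunk.
first.fill_(1.0)

# The first chunk starts at the same memory address as the original tensor.
shares_storage = first.data_ptr() == t.data_ptr()

print(t)
print(shares_storage)
```

If independent copies are needed, calling `.clone()` on a chunk detaches it from the original storage.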

If split_size_or_sections is an integer, tensor is split into equally sized chunks where possible. The last chunk is smaller if the size of tensor along dim is not divisible by split_size_or_sections.
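A short sketch of the integer case, using an illustrative 1-D tensor of length 7 (not from the original page). Since 7 is not divisible by 3, the last chunk holds the remainder:

```python
import torch

# Illustrative input: 7 elements along dim 0.
x = torch.arange(7)

# Request chunks of size 3; the final chunk holds the single leftover element.
chunks = torch.split(x, 3)

# Collect the length of each chunk along dim 0.
sizes = [c.shape[0] for c in chunks]

print(sizes)
```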

If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
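The list case can be sketched along a non-default dimension as follows. The tensor `x` and the section sizes are illustrative. Note that the sizes must sum exactly to the size of tensor along dim; the sketch also shows the error raised when they do not:

```python
import torch

# Illustrative input with 6 columns.
x = torch.arange(12).reshape(2, 6)

# Split the columns (dim=1) into sections of width 1, 2, and 3.
left, middle, right = torch.split(x, [1, 2, 3], dim=1)
shapes = [tuple(p.shape) for p in (left, middle, right)]
print(shapes)

# Concatenating the sections along the same dim recovers the input.
roundtrip_ok = torch.equal(torch.cat((left, middle, right), dim=1), x)

# Sections that do not sum to the dimension size are rejected.
error = None
try:
    torch.split(x, [1, 2], dim=1)  # 1 + 2 != 6
except RuntimeError as exc:
    error = exc
print(type(error).__name__)
```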

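Parameters

tensor (Tensor) – tensor to split.
split_size_or_sections (int or list of int) – size of a single chunk, or list of sizes for each chunk.
dim (int) – dimension along which to split the tensor. Defaults to 0.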

Return type

tuple[torch.Tensor, …]

Example:

>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))