torch.chunk — PyTorch 2.7 documentation

torch.chunk(input: Tensor, chunks: int, dim: int = 0) → Tuple[Tensor, ...]

Attempts to split a tensor into the specified number of chunks. Each chunk is a view of the input tensor.
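As a minimal sketch of the view behavior described above (assuming PyTorch is installed; the variable names are illustrative only), writing into a chunk in place also changes the input tensor, because the chunk shares its storage. The second part passes dim explicitly to split a 2-D tensor along its columns.

```python
import torch

# A 1-D tensor split into three chunks of two elements each.
x = torch.arange(6)
parts = x.chunk(3)

# Chunks are views: an in-place write through a chunk is visible in x.
parts[0][0] = 100
print(x)  # tensor([100,   1,   2,   3,   4,   5])

# Splitting along dim=1 divides the columns of a 3x4 tensor into two halves.
m = torch.arange(12).reshape(3, 4)
left, right = m.chunk(2, dim=1)
print(left.shape, right.shape)  # torch.Size([3, 2]) torch.Size([3, 2])
```

If an independent copy is needed, calling clone() on a chunk is one way to detach it from the input's storage.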

Note

This function may return fewer than the specified number of chunks!

If the tensor size along the given dimension dim is divisible by chunks, all returned chunks will be the same size. If it is not divisible by chunks, all returned chunks will be the same size except the last one, which will be smaller. Each full chunk has size ceil(size / chunks), so when pieces of that size cover the dimension in fewer than chunks steps, this function returns fewer than the specified number of chunks.
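The sizing rule can be sketched in plain Python. The helper below, chunk_sizes, is hypothetical and not part of torch. It assumes that every chunk except possibly the last has size ceil(size / chunks), which is consistent with the documented examples, and it only covers size >= 1 and chunks >= 1.

```python
def chunk_sizes(size: int, chunks: int) -> list[int]:
    """Sizes of the pieces along the split dimension.

    Hypothetical helper for illustration; it is not a torch function.
    """
    if size < 1 or chunks < 1:
        raise ValueError("this sketch assumes size >= 1 and chunks >= 1")
    # Assumed rule: every chunk except possibly the last has this size.
    step = (size + chunks - 1) // chunks  # integer ceil(size / chunks)
    # Number of full chunks and the leftover elements for a final short chunk.
    full, rest = divmod(size, step)
    return [step] * full + ([rest] if rest else [])


for n in (11, 12, 13):
    sizes = chunk_sizes(n, 6)
    print(n, len(sizes), sizes)
# 11 6 [2, 2, 2, 2, 2, 1]
# 12 6 [2, 2, 2, 2, 2, 2]
# 13 5 [3, 3, 3, 3, 1]
```

With 13 elements and 6 requested chunks, the chunk size rounds up to 3, so four full chunks plus one leftover element already cover the dimension and only 5 chunks come back.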

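Parameters

input (Tensor): the tensor to split
chunks (int): the number of chunks to return
dim (int): the dimension along which to split the tensor (default: 0)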

Example

>>> torch.arange(11).chunk(6)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4, 5]),
 tensor([6, 7]),
 tensor([8, 9]),
 tensor([10]))
>>> torch.arange(12).chunk(6)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4, 5]),
 tensor([6, 7]),
 tensor([8, 9]),
 tensor([10, 11]))
>>> torch.arange(13).chunk(6)
(tensor([0, 1, 2]),
 tensor([3, 4, 5]),
 tensor([6, 7, 8]),
 tensor([ 9, 10, 11]),
 tensor([12]))