torch.unflatten — PyTorch 2.7 documentation

torch.unflatten(input, dim, sizes) → Tensor

Expands a dimension of the input tensor over multiple dimensions.

See also

torch.flatten(): the inverse of this function. It coalesces several dimensions into one.
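As an illustration of the inverse relationship, the following minimal sketch splits one dimension with torch.unflatten and merges it back with torch.flatten. The tensor shape and the variable names (x, y, z, round_trip_ok) are chosen only for this example.

```python
import torch

# A 2 x 12 tensor holding the values 0..23.
x = torch.arange(24).reshape(2, 12)

# Split dimension 1 (size 12) into two dimensions of sizes 3 and 4.
y = torch.unflatten(x, 1, (3, 4))

# Merge dimensions 1 through 2 back into a single dimension.
z = torch.flatten(y, start_dim=1, end_dim=2)

# The round trip should reproduce the original tensor.
round_trip_ok = torch.equal(x, z)
print(y.shape, z.shape, round_trip_ok)
```

Note that torch.flatten needs the range of dimensions to merge (start_dim and end_dim), while torch.unflatten needs the single dimension to split and its new sizes.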

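Parameters

input (Tensor): the input tensor.
dim (int): Dimension to be unflattened, specified as an index into input.shape.
sizes (Tuple[int]): New shape of the unflattened dimension. One of its elements can be -1, in which case the corresponding output dimension is inferred. Otherwise, the product of sizes must equal input.shape[dim].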

Returns

A View of input with the specified dimension unflattened.
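Because the result is a view, it shares memory with the input rather than copying it. The sketch below checks this for a small contiguous tensor; the names base, view, and shares_storage are illustrative, and the example assumes a tensor that does not require gradients so the in-place write is allowed.

```python
import torch

# Contiguous 2 x 6 input filled with zeros.
base = torch.zeros(2, 6)

# Split dimension 1 (size 6) into (2, 3); the result has shape (2, 2, 3).
view = torch.unflatten(base, 1, (2, 3))

# Writing through the view modifies the input as well.
# Index (0, 1, 2) in the view maps to column 1 * 3 + 2 = 5 of row 0 in base.
view[0, 1, 2] = 7.0

# Both tensors start at the same memory address.
shares_storage = view.data_ptr() == base.data_ptr()
print(base[0, 5].item(), shares_storage)
```

If an independent copy is needed, calling .clone() on the result is one way to detach it from the input's memory.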

Examples::

    >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
    torch.Size([3, 2, 2, 1])
    >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
    torch.Size([3, 2, 2, 1])
    >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
    torch.Size([5, 2, 2, 3, 1, 1, 3])