Unflatten — PyTorch 2.7 documentation

class torch.nn.Unflatten(dim, unflattened_size)

Unflattens dimension dim of a tensor, expanding it into the shape given by unflattened_size. For use with nn.Sequential.
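A minimal sketch of what the expansion means, using only public torch calls. It checks that the module's output matches Tensor.unflatten and an equivalent reshape, and that nn.Flatten over the new dimensions restores the input. The tensor sizes are arbitrary choices for illustration.

```python
import torch
from torch import nn

# A batch of 4 flat feature vectors with 12 features each (sizes are arbitrary).
x = torch.randn(4, 12)

# Expand dimension 1 (size 12) into three dimensions of sizes 2, 3 and 2.
unflatten = nn.Unflatten(1, (2, 3, 2))
y = unflatten(x)
print(y.shape)  # torch.Size([4, 2, 3, 2])

# The result matches the Tensor method and an explicit reshape.
assert torch.equal(y, x.unflatten(1, (2, 3, 2)))
assert torch.equal(y, x.reshape(4, 2, 3, 2))

# Flattening dimensions 1..3 back together undoes the expansion.
flatten = nn.Flatten(start_dim=1, end_dim=3)
assert torch.equal(flatten(y), x)
```

Because only the sizes change and the element order is preserved, Unflatten and Flatten with matching dimensions act as inverses of each other.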

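Shape:

Input: $(*, S_{\text{dim}}, *)$, where $S_{\text{dim}}$ is the size at dimension dim and $*$ means any number of dimensions, including none.
Output: $(*, U_1, \ldots, U_n, *)$, where $U$ = unflattened_size and $\prod_{i=1}^{n} U_i = S_{\text{dim}}$.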
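Unflatten replaces the size $S_{\text{dim}}$ at dim with the entries of unflattened_size, whose product must equal $S_{\text{dim}}$. The sketch below applies that rule by hand and compares it with the module. expected_unflatten_shape is a hypothetical helper written for this illustration and is not part of PyTorch. The constructor does not check the product; a mismatch only surfaces when the module is called.

```python
import math

import torch
from torch import nn


def expected_unflatten_shape(shape, dim, sizes):
    """Hypothetical helper: apply the shape rule by hand.

    Replaces shape[dim] with the entries of `sizes` after checking that
    their product equals shape[dim]. Does not handle -1 entries.
    """
    dim = dim % len(shape)  # map negative dims to their positive index
    if math.prod(sizes) != shape[dim]:
        raise ValueError(f"{tuple(sizes)} does not multiply to {shape[dim]}")
    return tuple(shape[:dim]) + tuple(sizes) + tuple(shape[dim + 1:])


x = torch.randn(3, 24, 7)

# Middle dimension: (3, 24, 7) -> (3, 2, 3, 4, 7), since 2 * 3 * 4 == 24.
out = nn.Unflatten(1, (2, 3, 4))(x)
assert out.shape == expected_unflatten_shape(x.shape, 1, (2, 3, 4))
print(out.shape)  # torch.Size([3, 2, 3, 4, 7])

# 2 * 3 * 5 == 30 != 24, so the forward call raises a RuntimeError.
rejected = False
try:
    nn.Unflatten(1, (2, 3, 5))(x)
except RuntimeError as err:
    rejected = True
    print("RuntimeError:", err)
```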

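Parameters

dim (Union[int, str]): Dimension to be unflattened. Use an int for a regular Tensor and a str (the dimension name) for a named tensor.
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension. For a regular Tensor this is a tuple of ints, a list of ints, or a torch.Size; for a named tensor it is a NamedShape, i.e. a tuple of (name, size) tuples.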
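A short sketch of the two parameters in use, assuming a plain (unnamed) input. An int dim may be negative, and unflattened_size may be a list of ints. An int dim paired with (name, size) tuples is rejected with a TypeError when the module is constructed, before any forward call.

```python
import torch
from torch import nn

x = torch.randn(2, 50)

# Negative dim counts from the end; a list of ints is accepted as the new shape.
m = nn.Unflatten(-1, [2, 5, 5])
print(m(x).shape)  # torch.Size([2, 2, 5, 5])

# An int dim requires integer sizes, so (name, size) pairs fail in the constructor.
raised = False
try:
    nn.Unflatten(1, (("C", 2), ("H", 25)))
except TypeError as err:
    raised = True
    print("TypeError:", err)
```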

Examples

```python
>>> import torch
>>> from torch import nn
>>> input = torch.randn(2, 50)
```

With tuple of ints

```python
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, (2, 5, 5))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
```

With torch.Size

```python
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, torch.Size([2, 5, 5]))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
```

With a NamedShape (tuple of (name, size) tuples), for named tensors

```python
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
```
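A common place for Unflatten inside nn.Sequential is between a linear layer that produces flat features and a convolutional layer that expects (C, H, W) feature maps. The sketch below is a hypothetical decoder head; the layer sizes (a 16-d code, an 8-channel 4x4 map, a 3-channel output) are illustrative assumptions, not values from this page.

```python
import torch
from torch import nn

# Hypothetical decoder head: project a 16-d code to 8 * 4 * 4 features,
# reshape them into an 8-channel 4x4 map, then upsample 2x with a transposed conv.
decoder = nn.Sequential(
    nn.Linear(16, 8 * 4 * 4),
    nn.Unflatten(1, (8, 4, 4)),
    nn.ConvTranspose2d(8, 3, kernel_size=2, stride=2),
)

z = torch.randn(5, 16)  # batch of 5 codes
img = decoder(z)
print(img.shape)  # torch.Size([5, 3, 8, 8])
```

Putting the reshape in the module list keeps the whole pipeline a single Sequential, instead of calling view or reshape inside a custom forward method.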

NamedShape

alias of tuple[tuple[str, int]]
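Because NamedShape is only a type alias, any tuple of (name, size) pairs can be passed directly. A sketch, assuming named-tensor support (a prototype feature that may emit a UserWarning), showing that the new dimensions take the given names:

```python
import torch
from torch import nn

# Named tensors are a prototype feature; PyTorch may warn when creating one.
x = torch.randn(2, 50, names=("N", "features"))

# One (name, size) pair per new dimension; 2 * 5 * 5 == 50.
named_shape = (("C", 2), ("H", 5), ("W", 5))

out = nn.Unflatten("features", named_shape)(x)
print(out.names)  # ('N', 'C', 'H', 'W')
print(out.shape)  # torch.Size([2, 2, 5, 5])
```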