Flatten (PyTorch 2.7 documentation)

class torch.nn.Flatten(start_dim=1, end_dim=-1)

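Flattens a contiguous range of dimensions of the input tensor into a single dimension.

This is the module form of torch.flatten(), intended for use inside nn.Sequential; see torch.flatten() for details. Note that the defaults differ: nn.Flatten starts at dim 1, so the batch dimension (dim 0) is preserved, whereas torch.flatten() starts at dim 0.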
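A minimal sketch comparing the module with its functional counterparts, assuming an arbitrary example tensor of shape (4, 3, 8, 8). Because torch.flatten() defaults to start_dim=0, start_dim=1 has to be passed explicitly to reproduce nn.Flatten()'s default behavior; the reshape line is included only as an equivalent formulation for this case.

```python
import torch
from torch import nn

# Arbitrary example tensor: 4 samples, 3 channels, 8x8 spatial size.
x = torch.randn(4, 3, 8, 8)

# Module form with its defaults (start_dim=1, end_dim=-1).
module_out = nn.Flatten()(x)

# Functional form: torch.flatten defaults to start_dim=0,
# so start_dim=1 is passed explicitly to match nn.Flatten().
functional_out = torch.flatten(x, start_dim=1)

# Equivalent reshape for this case: keep dim 0, merge all remaining dims.
reshape_out = x.reshape(x.size(0), -1)

print(module_out.shape)                          # torch.Size([4, 192])
print(torch.equal(module_out, functional_out))   # True
print(torch.equal(module_out, reshape_out))      # True
```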

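Shape:

- Input: $(*, S_{\text{start}}, \ldots, S_i, \ldots, S_{\text{end}}, *)$, where $S_i$ is the size at dimension $i$ and $*$ means any number of dimensions, including none.
- Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_i, *)$.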
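In other words, the sizes from start_dim through end_dim (inclusive) are multiplied into one dimension, and every other dimension is left unchanged. The sketch below encodes that rule in a hypothetical helper, expected_flatten_shape (not part of PyTorch), and checks it against nn.Flatten for a few (start_dim, end_dim) pairs. It assumes a tensor with at least one dimension and dim indices within [-ndim, ndim).

```python
import math

import torch
from torch import nn


def expected_flatten_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: apply the Shape rule to a size tuple.

    Assumes len(shape) >= 1 and both dims lie in [-ndim, ndim).
    """
    ndim = len(shape)
    # Wrap negative indices to their non-negative equivalents.
    start = start_dim % ndim
    end = end_dim % ndim
    # The merged dimension is the product of sizes start..end inclusive.
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


x = torch.randn(2, 3, 4, 5)
for start, end in [(1, -1), (0, 2), (1, 2), (2, 3)]:
    actual = tuple(nn.Flatten(start, end)(x).shape)
    # The hand-computed shape should match what the module produces.
    assert actual == expected_flatten_shape(x.shape, start, end)
    print(start, end, actual)
```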

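Parameters

- start_dim (int): first dim to flatten (default = 1).
- end_dim (int): last dim to flatten (default = -1).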
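A short sketch of how start_dim and end_dim behave, using the same (32, 1, 5, 5) input as the examples on this page. It shows that the default keeps the batch dimension, that negative indices count from the end, and that start_dim == end_dim leaves the shape unchanged.

```python
import torch
from torch import nn

x = torch.randn(32, 1, 5, 5)

# Default start_dim=1 keeps the batch dimension (dim 0) intact.
print(nn.Flatten()(x).shape)                              # torch.Size([32, 25])

# start_dim=0 also merges the batch dimension: 32 * 1 * 5 * 5 = 800.
print(nn.Flatten(start_dim=0)(x).shape)                   # torch.Size([800])

# Negative indices count from the end: merge only the last two dims.
print(nn.Flatten(start_dim=-2, end_dim=-1)(x).shape)      # torch.Size([32, 1, 25])

# With start_dim == end_dim there is nothing to merge; the shape is unchanged.
print(nn.Flatten(start_dim=2, end_dim=2)(x).shape)        # torch.Size([32, 1, 5, 5])
```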

Examples:

import torch
from torch import nn

input = torch.randn(32, 1, 5, 5)

With default parameters

m = nn.Flatten()
output = m(input)
output.size()  # torch.Size([32, 25])

With non-default parameters

m = nn.Flatten(0, 2)
output = m(input)
output.size()  # torch.Size([160, 5])
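A typical place for nn.Flatten is inside nn.Sequential, between convolutional layers and a fully connected layer. The sketch below is a hypothetical small classifier for (N, 1, 5, 5) inputs; the layer sizes (4 output channels, 3x3 kernel, 10 outputs) are arbitrary choices for illustration, not anything prescribed by this API.

```python
import torch
from torch import nn

# Hypothetical classifier; layer sizes are arbitrary illustration choices.
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),  # (N, 1, 5, 5) -> (N, 4, 3, 3)
    nn.ReLU(),
    nn.Flatten(),        # (N, 4, 3, 3) -> (N, 36); batch dim 0 is kept
    nn.Linear(36, 10),   # (N, 36) -> (N, 10)
)

input = torch.randn(32, 1, 5, 5)
logits = model(input)
print(logits.shape)  # torch.Size([32, 10])
```

Because the default start_dim=1 preserves the batch dimension, the Linear layer's in_features only has to match the per-sample size, 4 * 3 * 3 = 36.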