Unflatten — PyTorch 2.7 documentation
class torch.nn.Unflatten(dim, unflattened_size)
Unflattens a tensor dimension, expanding it to a desired shape. For use with Sequential.
- dim specifies the dimension of the input tensor to be unflattened. It can be an int when a Tensor is used, or a str when a NamedTensor is used.
- unflattened_size is the new shape of the unflattened dimension of the tensor. For Tensor input it can be a tuple of ints, a list of ints, or a torch.Size; for NamedTensor input it is a NamedShape (a tuple of (name, size) tuples).
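As a rough illustration of the "for use with Sequential" remark, the sketch below pairs nn.Unflatten with nn.Flatten so that a linear layer can operate on flattened features before the original layout is restored. The layer sizes and the toy model are assumptions made up for this example and do not come from the documentation.

```python
import torch
from torch import nn

# Hypothetical toy model: flatten an image-like input, apply a linear layer,
# then restore the (C, H, W) layout with nn.Unflatten.
model = nn.Sequential(
    nn.Flatten(start_dim=1),      # (N, 3, 4, 4) -> (N, 48)
    nn.Linear(48, 48),            # keeps the feature count at 48
    nn.Unflatten(1, (3, 4, 4)),   # (N, 48) -> (N, 3, 4, 4)
)

x = torch.randn(8, 3, 4, 4)
y = model(x)
print(y.shape)  # torch.Size([8, 3, 4, 4])
```

Because Unflatten is a module rather than a function call, it can sit inside the Sequential container without a custom forward method.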
Shape:
- Input: $(*, S_{\text{dim}}, *)$, where $S_{\text{dim}}$ is the size at dimension dim and $*$ means any number of dimensions including none.
- Output: $(*, U_1, ..., U_n, *)$, where $U$ = unflattened_size and $\prod_{i=1}^n U_i = S_{\text{dim}}$.
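The product constraint in the shape rule can be checked directly. The following is a minimal sketch with arbitrarily chosen sizes; it also assumes that a mismatched product surfaces as a RuntimeError when the module is applied, which is the behavior of the underlying tensor operation as far as I am aware.

```python
import math

import torch
from torch import nn

x = torch.randn(4, 24, 7)   # S_dim = 24 at dim=1
sizes = (2, 3, 4)           # U_1 * U_2 * U_3 = 24

# The product of the new sizes must equal the size of the unflattened dim.
assert math.prod(sizes) == x.shape[1]

y = nn.Unflatten(1, sizes)(x)
print(y.shape)  # torch.Size([4, 2, 3, 4, 7])

# Assumption: a mismatched product is rejected with a RuntimeError
# at call time, not when the module is constructed.
try:
    nn.Unflatten(1, (5, 5))(x)   # 5 * 5 = 25 != 24
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
print(mismatch_rejected)
```

The dimensions before and after dim (here 4 and 7) pass through unchanged, which is what the leading and trailing $*$ in the shape rule express.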
Parameters
- dim (Union[int, str]) – Dimension to be unflattened
- unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – New shape of the unflattened dimension
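To make the accepted forms of unflattened_size concrete, here is a small sketch that passes the same shape as a tuple, a list, and a torch.Size, and compares the result with the Tensor.unflatten method. The comparison with Tensor.unflatten is an illustration of equivalent behavior on this input, not a statement about how the module is implemented.

```python
import torch
from torch import nn

x = torch.randn(2, 50)

# The same target shape given as a tuple, a list, and a torch.Size.
as_tuple = nn.Unflatten(1, (2, 5, 5))(x)
as_list = nn.Unflatten(1, [2, 5, 5])(x)
as_size = nn.Unflatten(1, torch.Size([2, 5, 5]))(x)

print(as_tuple.shape)  # torch.Size([2, 2, 5, 5])

# All three forms should yield identical tensors.
forms_agree = torch.equal(as_tuple, as_list) and torch.equal(as_tuple, as_size)
print(forms_agree)

# The tensor method gives the same values on this input.
matches_method = torch.equal(as_tuple, x.unflatten(1, (2, 5, 5)))
print(matches_method)
```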
Examples
>>> import torch
>>> from torch import nn
>>> input = torch.randn(2, 50)
With tuple of ints
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, (2, 5, 5))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
With torch.Size
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, torch.Size([2, 5, 5]))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
With NamedShape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
NamedShape