torch.movedim — PyTorch 2.7 documentation

torch.movedim(input, source, destination) → Tensor

Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.

Other dimensions of input that are not explicitly moved remain in their original order and appear at the positions not specified in destination.
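The placement rule above can be sketched in plain Python as a permutation of dimension indices. This is an illustrative sketch, not PyTorch's implementation: the helper name `movedim_order` is hypothetical, and it assumes negative positions count from the end, as with ordinary tensor indexing.

```python
def movedim_order(ndim, source, destination):
    """Return a list `order` where order[new_pos] is the original dim placed there.

    Hypothetical helper for illustration; it only models the index bookkeeping.
    """
    # Accept a single int or a sequence of ints for both arguments.
    src = [source] if isinstance(source, int) else list(source)
    dst = [destination] if isinstance(destination, int) else list(destination)
    if len(src) != len(dst):
        raise ValueError("source and destination must have the same length")

    def normalize(dim):
        # Assumption: negative positions wrap around, and out-of-range ones are errors.
        if not -ndim <= dim < ndim:
            raise IndexError(f"dimension {dim} out of range for {ndim} dims")
        return dim % ndim

    src = [normalize(s) for s in src]
    dst = [normalize(d) for d in dst]
    if len(set(src)) != len(src) or len(set(dst)) != len(dst):
        raise ValueError("positions in source and in destination must be unique")

    # Place every explicitly moved dimension at its requested position.
    order = [None] * ndim
    for s, d in zip(src, dst):
        order[d] = s

    # Fill the unspecified positions with the remaining dims in original order.
    remaining = iter(dim for dim in range(ndim) if dim not in src)
    for pos in range(ndim):
        if order[pos] is None:
            order[pos] = next(remaining)
    return order


shape = (3, 2, 1)
order = movedim_order(len(shape), (1, 2), (0, 1))
print(order)                      # [1, 2, 0]
print([shape[i] for i in order])  # [2, 1, 3]
```

Under these assumptions, passing the resulting `order` to `Tensor.permute` should give the same layout as the corresponding `torch.movedim` call.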

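Parameters

input (Tensor): the input tensor.

source (int or tuple of ints): original positions of the dims to move. These must be unique.

destination (int or tuple of ints): destination positions for each of the original dims. These must also be unique.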

Examples:

    >>> t = torch.randn(3,2,1)
    >>> t
    tensor([[[-0.3362],
             [-0.8437]],

            [[-0.9627],
             [ 0.1727]],

            [[ 0.5173],
             [-0.1398]]])
    >>> torch.movedim(t, 1, 0).shape
    torch.Size([2, 3, 1])
    >>> torch.movedim(t, 1, 0)
    tensor([[[-0.3362],
             [-0.9627],
             [ 0.5173]],

            [[-0.8437],
             [ 0.1727],
             [-0.1398]]])
    >>> torch.movedim(t, (1, 2), (0, 1)).shape
    torch.Size([2, 1, 3])
    >>> torch.movedim(t, (1, 2), (0, 1))
    tensor([[[-0.3362, -0.9627,  0.5173]],

            [[-0.8437,  0.1727, -0.1398]]])
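As a further illustration, a common use is converting a channels-last batch to channels-first. The sketch below is not part of the original examples: the tensor sizes and variable names are made up, and it assumes a negative position counts from the last dimension.

```python
import torch

# Hypothetical batch in channels-last layout: (N, H, W, C).
images_nhwc = torch.arange(4 * 32 * 48 * 3).reshape(4, 32, 48, 3)

# Move the last dimension (channels) to position 1.
# N, H and W are not moved, so they keep their relative order.
images_nchw = torch.movedim(images_nhwc, -1, 1)
print(images_nchw.shape)  # torch.Size([4, 3, 32, 48])

# The same layout written as an explicit permutation of all four dims.
same_as_permute = torch.equal(images_nchw, images_nhwc.permute(0, 3, 1, 2))
print(same_as_permute)  # True
```

With `movedim`, only the moved dimension has to be named, whereas `permute` needs the full ordering spelled out.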