torch.swapaxes — PyTorch 2.7 documentation

torch.swapaxes(input, axis0, axis1) → Tensor

Alias for torch.transpose().

This function is equivalent to NumPy’s swapaxes function.
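As a quick check of the two statements above, the following sketch compares `torch.swapaxes` with `torch.transpose` and with `numpy.swapaxes` on a small tensor. The tensor shape and the variable names are illustrative assumptions, not part of the original page.

```python
import numpy as np
import torch

# Illustrative tensor with distinct dimension sizes (an assumption for this sketch).
x = torch.arange(24).reshape(2, 3, 4)

# Swap the first and last axes three different ways.
via_swapaxes = torch.swapaxes(x, 0, 2)
via_transpose = torch.transpose(x, 0, 2)
via_numpy = np.swapaxes(x.numpy(), 0, 2)

# Both comparisons should hold if swapaxes behaves as an alias of transpose
# and matches NumPy's swapaxes.
same_as_transpose = torch.equal(via_swapaxes, via_transpose)
same_as_numpy = np.array_equal(via_swapaxes.numpy(), via_numpy)
print(same_as_transpose, same_as_numpy)
```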

Examples:

    >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]])
    >>> x
    tensor([[[0, 1],
             [2, 3]],

            [[4, 5],
             [6, 7]]])
    >>> torch.swapaxes(x, 0, 1)
    tensor([[[0, 1],
             [4, 5]],

            [[2, 3],
             [6, 7]]])
    >>> torch.swapaxes(x, 0, 2)
    tensor([[[0, 4],
             [2, 6]],

            [[1, 5],
             [3, 7]]])
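
Because the tensor in the examples above is 2×2×2, swapping axes does not visibly change its shape. The sketch below uses distinct dimension sizes (an illustrative choice, not taken from the original page) to show how the shape and the element positions change. It also uses a negative axis index, which `torch.transpose` accepts.

```python
import torch

# Illustrative tensor with shape (2, 3, 4); the sizes are an assumption for this sketch.
x = torch.arange(24).reshape(2, 3, 4)

# Swap axis 0 and axis 2: the sizes at those two positions trade places.
y = torch.swapaxes(x, 0, 2)
shape_before = tuple(x.shape)
shape_after = tuple(y.shape)
print(shape_before, shape_after)

# The element at x[i, j, k] ends up at y[k, j, i].
i, j, k = 1, 2, 3
element_moved = x[i, j, k].item() == y[k, j, i].item()

# A negative index counts from the last axis, so -1 refers to axis 2 here.
z = torch.swapaxes(x, 0, -1)
negative_axis_matches = torch.equal(y, z)
print(element_moved, negative_axis_matches)
```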