torch.diagonal_scatter - PyTorch 2.7 documentation

torch.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) → Tensor

Embeds the values of the src tensor into input along the diagonal elements of input, with respect to dim1 and dim2.
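One way to read this description: the result matches cloning input and then copying src into the view that torch.diagonal returns for the clone. The helper reference_diagonal_scatter below is a hypothetical name used only for illustration. It is a sketch of the observable semantics built from public APIs, not a claim about how PyTorch implements the operator.

```python
import torch

def reference_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
    # Hypothetical helper: clone first so the caller's tensor is never modified.
    out = input.clone()
    # torch.diagonal returns a view into `out`; copy_ writes src through that view.
    torch.diagonal(out, offset, dim1, dim2).copy_(src)
    return out

x = torch.arange(9.0).reshape(3, 3)
src = torch.tensor([10.0, 20.0, 30.0])

expected = torch.diagonal_scatter(x, src)
sketch = reference_diagonal_scatter(x, src)
print(torch.equal(expected, sketch))  # True
```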

This function returns a tensor with fresh storage; it does not return a view.
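The short check below illustrates the "fresh storage" guarantee by comparing data pointers and then mutating the result. For contrast, it also writes through torch.diagonal, which does return a view. Variable names and values are illustrative only.

```python
import torch

a = torch.zeros(3, 3)
out = torch.diagonal_scatter(a, torch.ones(3))

# The input is left untouched...
print(a.sum().item())                  # 0.0
# ...and the result does not share memory with it.
print(out.data_ptr() == a.data_ptr())  # False

# Mutating the result therefore has no effect on the input.
out[0, 1] = 5.0
print(a[0, 1].item())                  # 0.0

# By contrast, torch.diagonal returns a view, so writing through it mutates `a`.
torch.diagonal(a).fill_(2.0)
print(a[1, 1].item())                  # 2.0
```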

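The argument offset controls which diagonal to consider:

- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.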
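For an n x n matrix, the diagonal at offset k has n - |k| elements, so src shrinks as the offset moves away from the main diagonal. The sketch below, with illustrative values, fills the diagonals just above and just below the main diagonal of a 3x3 zero matrix.

```python
import torch

a = torch.zeros(3, 3)

# offset > 0: the diagonal above the main one has 2 elements in a 3x3 matrix.
above = torch.diagonal_scatter(a, torch.tensor([1.0, 2.0]), offset=1)
# offset < 0: the diagonal below the main one also has 2 elements.
below = torch.diagonal_scatter(a, torch.tensor([1.0, 2.0]), offset=-1)

print(above)
# tensor([[0., 1., 0.],
#         [0., 0., 2.],
#         [0., 0., 0.]])
print(below)
# tensor([[0., 0., 0.],
#         [1., 0., 0.],
#         [0., 2., 0.]])
```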

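Parameters

- input (Tensor): the input tensor. Must be at least 2-dimensional.
- src (Tensor): the tensor to embed into input.
- offset (int, optional): which diagonal to consider. Default: 0 (main diagonal).
- dim1 (int, optional): first dimension with respect to which to take the diagonal. Default: 0.
- dim2 (int, optional): second dimension with respect to which to take the diagonal. Default: 1.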
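The dim1 and dim2 arguments pick the pair of dimensions that form each matrix, which makes batched inputs straightforward. The sketch below assumes a batch of two 3x3 matrices stored along dimension 0; the shapes and values are made up for illustration.

```python
import torch

# A batch of two 3x3 matrices, shape (2, 3, 3).
batch = torch.zeros(2, 3, 3)

# torch.diagonal over dims 1 and 2 keeps the batch dim first and
# appends the diagonal as the last dim, so src must be (2, 3).
print(torch.diagonal(batch, 0, 1, 2).shape)  # torch.Size([2, 3])

src = torch.tensor([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])

# Each matrix in the batch gets its own row of src on its main diagonal.
out = torch.diagonal_scatter(batch, src, offset=0, dim1=1, dim2=2)
print(out[0])
# tensor([[1., 0., 0.],
#         [0., 2., 0.],
#         [0., 0., 3.]])
print(out[1])
# tensor([[4., 0., 0.],
#         [0., 5., 0.],
#         [0., 0., 6.]])
```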

Note

src must be of the proper size in order to be embedded into input. Specifically, it should have the same shape as torch.diagonal(input, offset, dim1, dim2).
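Rather than working out the diagonal length by hand, one can ask torch.diagonal for the expected shape, as the note suggests. The sketch below uses a non-square 2x4 input with offset=1 (values chosen for illustration). It also shows a mismatched src being rejected, assuming the error surfaces as a RuntimeError.

```python
import torch

inp = torch.zeros(2, 4)  # a non-square 2x4 matrix

# Ask torch.diagonal for the slice shape, then build a src that matches it.
diag_shape = torch.diagonal(inp, offset=1).shape  # torch.Size([2])
src = torch.full(diag_shape, 7.0)
ok = torch.diagonal_scatter(inp, src, offset=1)
print(ok)
# tensor([[0., 7., 0., 0.],
#         [0., 0., 7., 0.]])

# A src whose shape differs from the diagonal slice is rejected.
rejected = False
try:
    torch.diagonal_scatter(inp, torch.ones(3), offset=1)
except RuntimeError as err:  # assumption: the error type is RuntimeError
    rejected = True
    print("rejected:", err)
```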

Examples:

>>> import torch
>>> a = torch.zeros(3, 3)
>>> a
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

>>> torch.diagonal_scatter(a, torch.ones(3), 0)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

>>> torch.diagonal_scatter(a, torch.ones(2), 1)
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])
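Because the result is a new tensor, diagonal_scatter can also sit inside an autograd graph without an in-place write to its input. This example, with illustrative values, checks which entries receive gradient after summing the output.

```python
import torch

x = torch.zeros(3, 3, requires_grad=True)
src = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Out-of-place: x itself is not written to, so autograd can track both inputs.
out = torch.diagonal_scatter(x, src)
out.sum().backward()

# Diagonal entries of x were overwritten, so they receive no gradient.
print(x.grad)
# tensor([[0., 1., 1.],
#         [1., 0., 1.],
#         [1., 1., 0.]])

# Each element of src lands in exactly one output position.
print(src.grad)
# tensor([1., 1., 1.])
```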