torch.diag (PyTorch 2.7 documentation)

torch.diag(input, diagonal=0, *, out=None) → Tensor

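If input is a vector (1-D tensor), returns a 2-D square tensor with the elements of input as the diagonal.
If input is a matrix (2-D tensor), returns a 1-D tensor with the diagonal elements of input.

The argument diagonal controls which diagonal to consider:

If diagonal = 0, it is the main diagonal.
If diagonal > 0, it is above the main diagonal.
If diagonal < 0, it is below the main diagonal.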

Parameters

input (Tensor) – the input tensor.
diagonal (int, optional) – the diagonal to consider.

Keyword Arguments

out (Tensor, optional) – the output tensor.
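The sign convention for diagonal and the out keyword can be checked with a few small tensors. This is a minimal sketch assuming a recent PyTorch release with torch.diag behaving as documented here; the tensors v, m and buf are illustrative only and are chosen with distinct, non-random entries so the selected positions are easy to read.

```python
import torch

# A 1-D input of length 3 and a 3x3 matrix with distinct entries.
v = torch.tensor([1.0, 2.0, 3.0])
m = torch.arange(9.0).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# diagonal > 0 selects a diagonal above the main one, < 0 one below it.
above = torch.diag(m, 1)   # elements m[0, 1], m[1, 2]
below = torch.diag(m, -1)  # elements m[1, 0], m[2, 1]
print(above)  # tensor([1., 5.])
print(below)  # tensor([3., 7.])

# For a vector, a negative diagonal places the values below the main diagonal.
lower = torch.diag(v, -1)
print(lower.shape)  # torch.Size([4, 4])

# out= writes the result into a preallocated tensor of the matching shape and dtype.
buf = torch.empty(3, 3)
torch.diag(v, out=buf)
print(torch.equal(buf, torch.diag(v)))  # True
```

Preallocating buf with the exact result shape avoids the resize that PyTorch performs (with a warning for non-empty tensors) when out has the wrong size.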

Examples:

Get the square matrix where the input vector is the diagonal:

>>> import torch
>>> a = torch.randn(3)
>>> a
tensor([ 0.5950,-0.0872, 2.3298])
>>> torch.diag(a)
tensor([[ 0.5950, 0.0000, 0.0000],
        [ 0.0000,-0.0872, 0.0000],
        [ 0.0000, 0.0000, 2.3298]])
>>> torch.diag(a, 1)
tensor([[ 0.0000, 0.5950, 0.0000, 0.0000],
        [ 0.0000, 0.0000,-0.0872, 0.0000],
        [ 0.0000, 0.0000, 0.0000, 2.3298],
        [ 0.0000, 0.0000, 0.0000, 0.0000]])
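The torch.diag(a, 1) call above turns a length-3 vector into a 4x4 matrix. More generally, a length-n vector placed on diagonal k gives a square matrix of side n + |k|, and reading diagonal k back out recovers the vector. The following is a small sketch that checks this rule for several offsets, assuming a deterministic input v instead of random values.

```python
import torch

v = torch.tensor([1.0, 2.0, 3.0])
n = v.numel()

for k in (-2, -1, 0, 1, 2):
    mat = torch.diag(v, k)
    # The result is square with side n + |k| (4x4 for n = 3, k = 1).
    assert mat.shape == (n + abs(k), n + abs(k))
    # Only the n entries on diagonal k are nonzero; everything else is 0.
    assert torch.count_nonzero(mat).item() == n
    # Extracting diagonal k from the 2-D result recovers the original vector.
    assert torch.equal(torch.diag(mat, k), v)

print("round trip ok")
```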

Get the k-th diagonal of a given matrix:

>>> a = torch.randn(3, 3)
>>> a
tensor([[-0.4264, 0.0255,-0.1064],
        [ 0.8795,-0.2429, 0.1374],
        [ 0.1029,-0.6482,-1.6300]])
>>> torch.diag(a, 0)
tensor([-0.4264,-0.2429,-1.6300])
>>> torch.diag(a, 1)
tensor([ 0.0255, 0.1374])
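For a 2-D input, diagonal k consists of the entries m[i, i + k] whose column index stays inside the matrix, which is why torch.diag(a, 1) above returns only two values. The sketch below spells out that index rule with a hypothetical pure-Python helper, diag_of_matrix (not part of PyTorch), and compares it against torch.diag and torch.diagonal(m, offset=k) on a non-square matrix. Only offsets that hit at least one element are tested, to avoid relying on how empty diagonals are handled.

```python
import torch

def diag_of_matrix(m: torch.Tensor, k: int = 0) -> torch.Tensor:
    # Hypothetical reference helper: collect m[i, i + k] for every row i
    # whose column index i + k lies in [0, cols).
    rows, cols = m.shape
    vals = [m[i, i + k] for i in range(rows) if 0 <= i + k < cols]
    return torch.stack(vals) if vals else m.new_empty((0,))

m = torch.arange(12.0).reshape(3, 4)  # non-square 2-D inputs are allowed
rows, cols = m.shape

# Valid offsets run from -(rows - 1) (bottom-left corner) to cols - 1 (top-right).
for k in range(-(rows - 1), cols):
    expected = torch.diag(m, k)
    assert torch.equal(diag_of_matrix(m, k), expected)
    # For 2-D inputs, torch.diag(m, k) matches torch.diagonal(m, offset=k).
    assert torch.equal(torch.diagonal(m, offset=k), expected)

print(torch.diag(m, 1))
```

torch.diagonal returns a view into m, whereas the values torch.diag returns for a matrix can be treated as an independent tensor; the comparison above only checks that the values agree.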