torch.trace - PyTorch 2.7 documentation

torch.trace(input) → Tensor

Returns the sum of the elements of the diagonal of the input 2-D matrix.

Example:

>>> x = torch.arange(1., 10.).view(3, 3)
>>> x
tensor([[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]])
>>> torch.trace(x)
tensor(15.)
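As a cross-check, the same result can be obtained by extracting the main diagonal and summing it. The sketch below reuses the matrix from the example. The batched variant at the end is not part of torch.trace itself; it is one possible workaround built on torch.diagonal for a stack of matrices, and the names via_trace, via_diagonal, batch, and batch_traces are illustrative only.

```python
import torch

# Same 3x3 matrix as in the example.
x = torch.arange(1., 10.).view(3, 3)

# The trace is the sum of the main-diagonal entries: 1 + 5 + 9.
via_trace = torch.trace(x)
via_diagonal = torch.diagonal(x).sum()
assert torch.equal(via_trace, via_diagonal)

# torch.trace is documented for 2-D input. For a batch of matrices,
# one option (an assumption of this sketch, not a torch.trace feature)
# is to take the diagonal over the last two dims and sum along it.
batch = torch.stack([x, 2 * x])  # shape (2, 3, 3)
batch_traces = torch.diagonal(batch, dim1=-2, dim2=-1).sum(-1)
print(batch_traces)  # tensor([15., 30.])
```

Summing torch.diagonal with dim1=-2 and dim2=-1 yields one trace per matrix in the batch, whereas torch.trace always returns a single scalar tensor.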