torch.trace - PyTorch 2.7 documentation
torch.trace(input) → Tensor
Returns the sum of the elements of the diagonal of the input 2-D matrix.
Example:
>>> import torch
>>> x = torch.arange(1., 10.).view(3, 3)
>>> x
tensor([[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]])
>>> torch.trace(x)
tensor(15.)
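The definition (sum of the main-diagonal elements) can be cross-checked against other ways of computing the same quantity. The minimal sketch below compares torch.trace with torch.diagonal(...).sum(), torch.einsum("ii->", ...), and an explicit Python loop over x[i, i]. These alternatives are illustrative choices made here for comparison only. They are not claimed to reflect how torch.trace is implemented internally.

```python
import torch

# Same 3x3 matrix as in the example: rows [1, 2, 3], [4, 5, 6], [7, 8, 9].
x = torch.arange(1., 10.).view(3, 3)

# torch.trace sums the main-diagonal elements of a 2-D tensor.
t = torch.trace(x)

# Illustrative cross-checks (not the library's internal implementation):
via_diagonal = torch.diagonal(x).sum()  # extract the main diagonal, then sum it
via_einsum = torch.einsum("ii->", x)    # repeated index i, summed to a scalar
via_loop = sum(x[i, i].item() for i in range(x.shape[0]))  # 1 + 5 + 9 in plain Python

print(t.item(), via_diagonal.item(), via_einsum.item(), via_loop)
# prints: 15.0 15.0 15.0 15.0
```

All four approaches agree here because each one adds up exactly the entries x[0, 0], x[1, 1], and x[2, 2].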