torch.func.hessian — PyTorch 2.7 documentation

torch.func.hessian(func, argnums=0)

Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.

The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for performance. Hessians can also be computed through other compositions of jacfwd() and jacrev(), such as jacfwd(jacfwd(func)) or jacrev(jacrev(func)).
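The following is a minimal sketch comparing the compositions mentioned above on an illustrative scalar function f (not part of the API). It assumes float64 inputs so that the comparisons are not disturbed by rounding; all four results should agree, and they differ only in how the derivatives are computed.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

def f(x):
    # Illustrative scalar-valued function R^N -> R^1
    return (x.sin() * x.cos()).sum()

x = torch.randn(4, dtype=torch.float64)

h_default = hessian(f)(x)          # forward-over-reverse (the default)
h_fwd_rev = jacfwd(jacrev(f))(x)   # the same composition, written out
h_fwd_fwd = jacfwd(jacfwd(f))(x)   # forward-over-forward
h_rev_rev = jacrev(jacrev(f))(x)   # reverse-over-reverse

# Every composition yields the same N x N Hessian
for h in (h_fwd_rev, h_fwd_fwd, h_rev_rev):
    assert torch.allclose(h_default, h)
```

The choice between them is about speed and operator coverage, not about the values returned.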

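Parameters

func (Callable) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
argnums (int or Tuple[int]) – Optional, integer or tuple of integers, saying which arguments to get the Hessian with respect to. Default: 0.
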
Returns

Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
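When argnums is a tuple, the returned Hessian is split into blocks. The sketch below assumes that result is a nested tuple in which H[i][j] holds the second derivatives with respect to argument i and then argument j, which is what composing jacfwd(jacrev(func, argnums), argnums) produces. The function g is illustrative, chosen so that every block has a simple closed form.

```python
import torch
from torch.func import hessian

def g(x, y):
    # Illustrative function: g(x, y) = sum_i x_i^2 * y_i
    return (x ** 2 * y).sum()

x = torch.randn(3, dtype=torch.float64)
y = torch.randn(3, dtype=torch.float64)

# Nested tuple of blocks: ((d2g/dx dx, d2g/dx dy), (d2g/dy dx, d2g/dy dy))
H = hessian(g, argnums=(0, 1))(x, y)
(h_xx, h_xy), (h_yx, h_yy) = H

# Gradients are 2*x*y (w.r.t. x) and x**2 (w.r.t. y); differentiate again
assert torch.allclose(h_xx, torch.diag(2 * y))
assert torch.allclose(h_xy, torch.diag(2 * x))
assert torch.allclose(h_yx, torch.diag(2 * x))
assert torch.allclose(h_yy, torch.zeros(3, 3, dtype=torch.float64))
```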

Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.
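One way to apply the jacrev(jacrev(func)) alternative automatically is a wrapper that tries hessian() first and falls back on failure. This is a sketch only: hessian_with_fallback is a hypothetical helper, not part of torch.func, and it assumes the missing-formula error is raised as a RuntimeError or a subclass (NotImplementedError is a subclass of RuntimeError in Python). Catching RuntimeError broadly also retries on unrelated errors, which then surface again from the fallback path.

```python
import torch
from torch.func import hessian, jacrev

def hessian_with_fallback(func, argnums=0):
    """Hypothetical helper: forward-over-reverse first, reverse-over-reverse
    if the forward-mode pass raises an error."""
    fast = hessian(func, argnums=argnums)
    slow = jacrev(jacrev(func, argnums=argnums), argnums=argnums)

    def wrapper(*args):
        try:
            return fast(*args)
        except RuntimeError:
            # Also catches NotImplementedError, a RuntimeError subclass
            return slow(*args)

    return wrapper

def f(x):
    # Illustrative function whose Hessian is diag(6 * x)
    return (x ** 3).sum()

x = torch.randn(4, dtype=torch.float64)
h = hessian_with_fallback(f)(x)
assert torch.allclose(h, torch.diag(6 * x))
```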

A basic usage with an R^N -> R^1 function gives an N x N Hessian:

```python
import torch
from torch.func import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
assert torch.allclose(hess, torch.diag(-x.sin()))
```
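Entry (i, j) of that N x N matrix is d^2 f / (dx_i dx_j). A quadratic form f(x) = x^T A x is a convenient check because its Hessian is the constant matrix A + A^T, whatever the input x. The sketch below uses an illustrative random matrix A and float64 inputs.

```python
import torch
from torch.func import hessian

torch.manual_seed(0)
N = 4
A = torch.randn(N, N, dtype=torch.float64)  # illustrative, not symmetric

def quad(x):
    # f(x) = x^T A x, a map R^N -> R^1
    return x @ A @ x

x = torch.randn(N, dtype=torch.float64)
H = hessian(quad)(x)

assert H.shape == (N, N)
assert torch.allclose(H, A + A.T)  # d^2/dx_i dx_j of x^T A x = A_ij + A_ji
```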