torch.func.hessian — PyTorch 2.7 documentation
torch.func.hessian(func, argnums=0)
Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.
The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for performance. Hessians can also be computed through other compositions of jacfwd() and jacrev(), such as jacfwd(jacfwd(func)) or jacrev(jacrev(func)).
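As a rough illustration of the compositions named above, the sketch below builds all three and checks that they agree on a small scalar-valued function. The function `f` and the input size are made up for the example.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

def f(x):
    # Hypothetical scalar-valued test function, R^4 -> R
    return (x ** 3).sum()

x = torch.randn(4)

h_default = hessian(f)(x)          # forward-over-reverse (the default)
h_fwd_rev = jacfwd(jacrev(f))(x)   # the same composition written out
h_rev_rev = jacrev(jacrev(f))(x)   # reverse-over-reverse
h_fwd_fwd = jacfwd(jacfwd(f))(x)   # forward-over-forward

# The second derivative of x^3 is 6x, so the Hessian is diagonal
expected = torch.diag(6 * x)
```

All four results should match `expected` up to floating-point tolerance. Which composition runs fastest depends on the function and the input and output sizes.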
Parameters
- func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]) – Optional. An integer or tuple of integers specifying which arguments to compute the Hessian with respect to. Default: 0.
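To make the argnums parameter concrete, here is a minimal sketch with a hypothetical two-argument function `f`. With an integer, the Hessian is taken with respect to that one argument. With a tuple, the result comes back as a nested tuple of blocks, indexed first by the argument of the first derivative and then by the argument of the second.

```python
import torch
from torch.func import hessian

def f(x, y):
    # Hypothetical function of two tensors that returns a scalar
    return (x * y ** 2).sum()

x = torch.randn(3)
y = torch.randn(3)

# Differentiate twice with respect to the second argument (index 1)
hess_y = hessian(f, argnums=1)(x, y)

# Differentiate with respect to both arguments; blocks[i][j] holds the
# derivative with respect to argument j of the gradient in argument i
blocks = hessian(f, argnums=(0, 1))(x, y)
mixed_xy = blocks[0][1]
```

For this `f`, `hess_y` equals `torch.diag(2 * x)` and `mixed_xy` equals `torch.diag(2 * y)`.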
Returns
Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
Note
You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.
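One way to act on this note is a small wrapper that tries the default strategy and falls back to reverse-over-reverse. The sketch below is only an illustration: `hessian_with_fallback` is a hypothetical helper, not part of torch.func, and the exception types it catches are an assumption about how the missing-operator error surfaces.

```python
import torch
from torch.func import hessian, jacrev

def hessian_with_fallback(func, argnums=0):
    # Hypothetical helper, not part of torch.func
    def wrapped(*args):
        try:
            # Default: forward-over-reverse
            return hessian(func, argnums=argnums)(*args)
        except (NotImplementedError, RuntimeError):
            # Assumption: missing forward-mode support raises one of these.
            # Catching RuntimeError this broadly can also hide unrelated errors.
            return jacrev(jacrev(func, argnums=argnums), argnums=argnums)(*args)
    return wrapped

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian_with_fallback(f)(x)
```

In real code it is probably safer to inspect the error message before falling back, so that genuine bugs in `func` are not masked.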
A basic usage with an R^N -> R^1 function gives an N x N Hessian:
import torch
from torch.func import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
assert torch.allclose(hess, torch.diag(-x.sin()))
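Two variations on the basic usage may help with reading output shapes. The sketch below assumes a vector-valued function `g` and a batch of inputs `xs`, both invented for illustration, and uses torch.func.vmap to compute one Hessian per batch element.

```python
import torch
from torch.func import hessian, vmap

def f(x):
    # Scalar-valued, R^5 -> R
    return x.sin().sum()

def g(x):
    # Hypothetical vector-valued function, R^5 -> R^5
    return x.sin()

xs = torch.randn(8, 5)  # batch of 8 inputs, each in R^5

# One N x N Hessian per batch element: shape (8, 5, 5)
batched_hess = vmap(hessian(f))(xs)

# For an R^N -> R^M function, the output dimension comes first: (M, N, N)
hess_g = hessian(g)(xs[0])
```

Here `batched_hess[i]` should match `torch.diag(-xs[i].sin())`, and `hess_g` has shape (5, 5, 5).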