jax.hessian — JAX documentation (original) (raw)
jax.hessian#
jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[source]#
Hessian of fun
as a dense array.
Parameters:
- fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by
argnums
should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof. - argnums (int | Sequence _[_int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default
0
). - has_aux (bool) – Optional, bool. Indicates whether
fun
returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. - holomorphic (bool) – Optional, bool. Indicates whether
fun
is promised to be holomorphic. Default False.
Returns:
A function with the same arguments as fun
, that evaluates the Hessian offun
.
Return type:
Callable
import jax
g = lambda x: x[0]*3 - 2x[0]*x[1] - x[1]**6 print(jax.hessian(g)(jax.numpy.array([1., 2.]))) [[ 6. -2.] [ -2. -480.]]
hessian() is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x)
is given by forming a tree product of the structure of fun(x)
with a tree product of two copies of the structure of x
. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:
import jax.numpy as jnp f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])} print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.})) {'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]], [[ 0., 0.], [ 0., 12.]]], dtype=float32), 'b': Array([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)}, 'b': {'a': Array([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32), 'b': Array([[[0. , 0. ], [0. , 0. ]], [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
Thus each leaf in the tree structure of jax.hessian(fun)(x)
corresponds to a leaf of fun(x)
and a pair of leaves of x
. For each leaf injax.hessian(fun)(x)
, if the corresponding array leaf of fun(x)
has shape (out_1, out_2, ...)
and the corresponding array leaves of x
have shape (in_1_1, in_1_2, ...)
and (in_2_1, in_2_2, ...)
respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)
. In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.
In particular, an array is produced (with no pytrees involved) when the function input x
and output fun(x)
are each a single array, as in theg
example above. If fun(x)
has shape (out1, out2, ...)
and x
has shape (in1, in2, ...)
then jax.hessian(fun)(x)
has shape(out1, out2, ..., in1, in2, ..., in1, in2, ...)
. To flatten pytrees into 1D vectors, consider using jax.flatten_util.flatten_pytree()
.