torch.func.jacrev - PyTorch 2.7 documentation

torch.func.jacrev(func, argnums=0, *, has_aux=False, chunk_size=None, _preallocate_and_copy=False)

Computes the Jacobian of func with respect to the arg(s) at index argnums using reverse-mode autodiff.

Note

Using chunk_size=1 is equivalent to computing the Jacobian row-by-row with a for-loop, i.e., the constraints of vmap() are not applicable.
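As a rough illustration of the note above, the following sketch compares the default call with a chunk_size=1 call on a small input. The variable names are illustrative only; the two calls are expected to agree on the result and differ only in how the rows are computed internally.

```python
import torch
from torch.func import jacrev

x = torch.randn(5)

# Default (chunk_size=None): the rows of the Jacobian are computed together.
jac_full = jacrev(torch.sin)(x)

# chunk_size=1: the rows are computed one at a time in a for-loop.
jac_rowwise = jacrev(torch.sin, chunk_size=1)(x)

# Both strategies should yield the same 5 x 5 Jacobian.
assert torch.allclose(jac_full, jac_rowwise)
```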

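Parameters

func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.

argnums (int or Tuple[int]) – Optional, integer or tuple of integers, saying which arguments to get the Jacobian with respect to. Default: 0.

has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.

chunk_size (None or int) – If None (default), use the maximum chunk size. If 1, compute the Jacobian row-by-row with a for-loop. Otherwise, compute the Jacobian chunk_size rows at a time. Default: None.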

Returns

Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is the auxiliary objects returned by func.
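To make the shape of the returned Jacobian concrete, here is a minimal sketch using a hypothetical linear map from R^3 to R^2 (the names W and linear are illustrative and not part of the API). For a single tensor input and output, the result has the output shape followed by the input shape.

```python
import torch
from torch.func import jacrev

# Hypothetical weight matrix, used only for illustration.
W = torch.randn(2, 3)

def linear(x):
    return W @ x

x = torch.randn(3)
jacobian = jacrev(linear)(x)

# One row per output element, one column per input element.
assert jacobian.shape == (2, 3)
# For a linear map, the Jacobian is the weight matrix itself.
assert torch.allclose(jacobian, W)
```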

A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian:

import torch
from torch.func import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)

If you would like to compute the output of the function as well as the Jacobian of the function, use the has_aux flag to return the output as an auxiliary object:

import torch
from torch.func import jacrev
x = torch.randn(5)

def f(x):
    return x.sin()

def g(x):
    # Return the output twice: once to differentiate, once as the aux value.
    result = f(x)
    return result, result

jacobian_f, f_x = jacrev(g, has_aux=True)(x)
assert torch.allclose(f_x, f(x))

jacrev() can be composed with vmap to produce batched Jacobians:

import torch
from torch.func import jacrev, vmap
x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
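One way to check what the batched result means is to compare it against a plain Python loop over the batch dimension. The sketch below assumes a small batch so the loop stays cheap; the variable names are illustrative.

```python
import torch
from torch.func import jacrev, vmap

x = torch.randn(4, 3)

# Batched: one 3 x 3 Jacobian per row of x, computed in a single call.
batched = vmap(jacrev(torch.sin))(x)

# Reference: compute each per-sample Jacobian separately and stack them.
looped = torch.stack([jacrev(torch.sin)(xi) for xi in x])

assert batched.shape == (4, 3, 3)
assert torch.allclose(batched, looped)
```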

Additionally, jacrev() can be composed with itself to produce Hessians:

import torch
from torch.func import jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian = jacrev(jacrev(f))(x)
assert torch.allclose(hessian, torch.diag(-x.sin()))
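As a cross-check, the nested jacrev result can be compared with torch.func.hessian on a function whose Hessian is easy to write down by hand. The cubic function below is a hypothetical example chosen for illustration: its Hessian is diagonal with entries 6 * x.

```python
import torch
from torch.func import hessian, jacrev

def f(x):
    # Scalar-valued function; its Hessian is diag(6 * x).
    return (x ** 3).sum()

x = torch.randn(5)

hess_nested = jacrev(jacrev(f))(x)  # reverse-over-reverse
hess_ref = hessian(f)(x)            # convenience API from torch.func

assert torch.allclose(hess_nested, torch.diag(6 * x))
assert torch.allclose(hess_nested, hess_ref)
```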

By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:

import torch
from torch.func import jacrev

def f(x, y):
    return x + y ** 2

x, y = torch.randn(5), torch.randn(5)
jacobian = jacrev(f, argnums=1)(x, y)
expected = torch.diag(2 * y)
assert torch.allclose(jacobian, expected)

Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments:

import torch
from torch.func import jacrev

def f(x, y):
    return x + y ** 2

x, y = torch.randn(5), torch.randn(5)
jacobian = jacrev(f, argnums=(0, 1))(x, y)
expected_x = torch.diag(torch.ones_like(x))
expected_y = torch.diag(2 * y)
assert torch.allclose(jacobian[0], expected_x)
assert torch.allclose(jacobian[1], expected_y)
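When the selected arguments have different shapes, each entry of the returned tuple has its own shape. The sketch below uses a hypothetical function matvec(x, W) = W @ x to illustrate this; the names are assumptions made for the example.

```python
import torch
from torch.func import jacrev

def matvec(x, W):
    return W @ x

x = torch.randn(3)
W = torch.randn(2, 3)

jac_x, jac_W = jacrev(matvec, argnums=(0, 1))(x, W)

# Output shape (2,) followed by the shape of each differentiated input.
assert jac_x.shape == (2, 3)
assert jac_W.shape == (2, 2, 3)

# d(W @ x)[i] / dx[k] = W[i, k]
assert torch.allclose(jac_x, W)
# d(W @ x)[i] / dW[j, k] = delta(i, j) * x[k]
expected_W = torch.einsum("ij,k->ijk", torch.eye(2), x)
assert torch.allclose(jac_W, expected_W)
```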

Note

Using PyTorch torch.no_grad together with jacrev.

Case 1: Using torch.no_grad inside a function:

def f(x):
    with torch.no_grad():
        c = x ** 2
    return x - c

In this case, jacrev(f)(x) will respect the inner torch.no_grad.
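A small sketch of what "respecting the inner torch.no_grad" means in practice: c is treated as a constant, so the Jacobian of x - c is the identity. The function f_tracked is a hypothetical counterpart without the context manager, added only for contrast.

```python
import torch
from torch.func import jacrev

def f(x):
    with torch.no_grad():
        c = x ** 2  # not differentiated: treated as a constant
    return x - c

def f_tracked(x):
    c = x ** 2      # differentiated as usual
    return x - c

x = torch.randn(5)

# With the inner no_grad, only the "x" term contributes to the Jacobian.
assert torch.allclose(jacrev(f)(x), torch.eye(5))
# Without it, the "- x ** 2" term contributes -2 * x on the diagonal.
assert torch.allclose(jacrev(f_tracked)(x), torch.diag(1 - 2 * x))
```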

Case 2: Using jacrev inside torch.no_grad context manager:

with torch.no_grad():
    jacrev(f)(x)

In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a “function transform”: its result should not depend on the result of a context manager outside of f.
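The behaviour described for Case 2 can be checked with a function that has no inner torch.no_grad. In this sketch, which uses torch.sin purely as an example, the Jacobian computed inside the outer context manager should match the one computed outside it.

```python
import torch
from torch.func import jacrev

x = torch.randn(5)

# Computed normally.
jac_outside = jacrev(torch.sin)(x)

# Computed under an outer torch.no_grad, which jacrev does not apply to func.
with torch.no_grad():
    jac_inside = jacrev(torch.sin)(x)

assert torch.allclose(jac_inside, jac_outside)
assert torch.allclose(jac_inside, torch.diag(torch.cos(x)))
```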