torch.func.vjp — PyTorch 2.7 documentation

torch.func.vjp(func, *primals, has_aux=False)

Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.
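In matrix terms: if func (viewed on a single flat input) is a function $f:\mathbb{R}^n \to \mathbb{R}^m$ with Jacobian $J_f(x) \in \mathbb{R}^{m \times n}$, then for a cotangent $v \in \mathbb{R}^m$ the returned function computes

$$\texttt{vjp\_fn}(v) = v^\top J_f(x) \in \mathbb{R}^n.$$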

Parameters

- func (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
- primals (Tensors) – Positional arguments to func that must all be Tensors. The returned vjp_fn computes the derivative with respect to these arguments.
- has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.

Returns

Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.
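For example, a minimal sketch of the has_aux form (the function f and its aux payload here are illustrative, not part of the API):

```python
import torch

x = torch.randn([5])

def f(x):
    out = x.sin().sum()
    aux = {'norm': x.norm()}   # extra output; returned as-is, not differentiated
    return out, aux

output, vjpfunc, aux = torch.func.vjp(f, x, has_aux=True)
grads = vjpfunc(torch.tensor(1.))
assert torch.allclose(grads[0], x.cos())
```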

When used in simple cases, vjp() behaves the same as grad()

```python
x = torch.randn([5])
f = lambda x: x.sin().sum()
(_, vjpfunc) = torch.func.vjp(f, x)
grad = vjpfunc(torch.tensor(1.))[0]
assert torch.allclose(grad, torch.func.grad(f)(x))
```
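Because vjp_fn computes $v^\top J$, the full Jacobian can be recovered one row at a time with one-hot cotangents (torch.func.jacrev automates this, using vmap to batch the rows). A minimal sketch:

```python
import torch

x = torch.randn(3)
f = lambda x: x.sin()                  # elementwise, so J = diag(cos(x))
(_, vjp_fn) = torch.func.vjp(f, x)

# Row i of the Jacobian is e_i^T J, i.e. the vjp with a one-hot cotangent.
rows = [vjp_fn(torch.eye(3)[i])[0] for i in range(3)]
jacobian = torch.stack(rows)
assert torch.allclose(jacobian, torch.diag(x.cos()))
```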

However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs

```python
x = torch.randn([5])
f = lambda x: (x.sin(), x.cos())
(_, vjpfunc) = torch.func.vjp(f, x)
vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
assert torch.allclose(vjps[0], x.cos() + -x.sin())
```

vjp() can even support outputs being Python structs

```python
x = torch.randn([5])
f = lambda x: {'first': x.sin(), 'second': x.cos()}
(_, vjpfunc) = torch.func.vjp(f, x)
cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
vjps = vjpfunc(cotangents)
assert torch.allclose(vjps[0], x.cos() + -x.sin())
```

The function returned by vjp() will compute the partials with respect to each of the primals

```python
x, y = torch.randn([5, 4]), torch.randn([4, 5])
(_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
cotangents = torch.randn([5, 5])
vjps = vjpfunc(cotangents)
assert len(vjps) == 2
assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
```
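The two asserts check the standard matrix-calculus identities: for $C = XY$ with cotangent $\bar C$,

$$\bar X = \bar C\, Y^\top, \qquad \bar Y = X^\top \bar C.$$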

primals are the positional arguments for f. All kwargs use their default values

```python
x = torch.randn([5])

def f(x, scale=4.):
    return x * scale

(_, vjpfunc) = torch.func.vjp(f, x)
vjps = vjpfunc(torch.ones_like(x))
assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
```
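To run func with a non-default keyword value, one option is to close over it in a wrapper so that only x remains a primal. A minimal sketch (the scale=2. value is illustrative):

```python
import torch

x = torch.randn([5])

def f(x, scale=4.):
    return x * scale

# Fix scale via a closure; vjp then only sees x as a primal.
(_, vjpfunc) = torch.func.vjp(lambda x: f(x, scale=2.), x)
vjps = vjpfunc(torch.ones_like(x))
assert torch.allclose(vjps[0], torch.full(x.shape, 2.))
```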

Note

Using PyTorch torch.no_grad together with vjp. Case 1: Using torch.no_grad inside a function:

```python
def f(x):
    with torch.no_grad():
        c = x ** 2
    return x - c
```

In this case, vjp(f, x) will respect the inner torch.no_grad.

Case 2: Using vjp inside torch.no_grad context manager:

```python
with torch.no_grad():
    vjp(f, x)
```

In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not depend on context managers active outside of f.
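A runnable sketch of both cases together: the inner no_grad is honored (c is treated as a constant), while the outer no_grad does not stop vjp from computing gradients:

```python
import torch
from torch.func import vjp

def f(x):
    with torch.no_grad():
        c = x ** 2          # constant under the inner no_grad
    return x - c

x = torch.randn([])
with torch.no_grad():
    out, vjp_fn = vjp(f, x)
    (grad,) = vjp_fn(torch.tensor(1.))

# d(x - c)/dx == 1 since c does not participate in differentiation
assert torch.allclose(grad, torch.tensor(1.))
```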