torch.func.grad — PyTorch 2.7 documentation

torch.func.grad(func, argnums=0, has_aux=False)

The grad operator computes gradients of func with respect to the input(s) specified by argnums. It can be nested to compute higher-order gradients.

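Parameters

func (Callable): A Python function that takes one or more arguments. It must return a single-element Tensor. If has_aux is True, it may instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).

argnums (int or Tuple[int]): The argument(s) to compute gradients with respect to. Can be a single integer or a tuple of integers. Default: 0.

has_aux (bool): Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.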

Returns

A function that computes gradients with respect to its inputs. By default, the output of that function is the gradient tensor(s) with respect to the first argument. If has_aux is True, it returns a tuple of the gradients and the auxiliary objects. If argnums is a tuple of integers, it returns a tuple of gradients, one for each argnums value.
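As a rough illustration of how argnums shapes the return value, the sketch below uses a hypothetical two-argument function f (not part of the original documentation) and checks each gradient against its closed form.

```python
import torch
from torch.func import grad

# Hypothetical scalar-valued function, used only for illustration.
def f(x, y):
    return (x * y).sum()

x = torch.randn(3)
y = torch.randn(3)

# Default argnums=0: a single tensor, the gradient with respect to x (which is y).
gx = grad(f)(x, y)

# argnums=1: a single tensor, the gradient with respect to y (which is x).
gy = grad(f, argnums=1)(x, y)

# argnums=(0, 1): a tuple holding one gradient per listed argument.
gx2, gy2 = grad(f, argnums=(0, 1))(x, y)

assert torch.allclose(gx, y) and torch.allclose(gy, x)
assert torch.allclose(gx2, gx) and torch.allclose(gy2, gy)
```

An integer argnums yields a bare tensor, while a tuple argnums yields a tuple, even when the tuple lists only some of the arguments.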

Return type

Callable

Example of using grad:

import torch
from torch.func import grad

x = torch.randn([])
# The derivative of sin(x) is cos(x)
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

Second-order gradients (reusing x from above):

neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())

When composed with vmap, grad can be used to compute per-sample gradients:

import torch
from torch.func import grad, vmap

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
# in_dims=(None, 0, 0): share weights, map over the batch dimension of examples and targets
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
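One way to sanity-check the per-sample result is to compare it with a plain Python loop that calls grad once per example. The sketch below repeats the model and loss so that it runs on its own; the loop-based reference and the fixed seed are additions for illustration only.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Vectorized per-sample gradients: one row per example.
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: call grad once per example and stack the results.
looped = torch.stack([
    grad(compute_loss)(weights, examples[i], targets[i])
    for i in range(batch_size)
])

assert per_sample.shape == (batch_size, feature_size)
assert torch.allclose(per_sample, looped, atol=1e-6)
```

The result has one gradient row per example, so its shape is (batch_size, feature_size) rather than the (feature_size,) shape of a gradient averaged over the batch.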

Example of using grad with has_aux and argnums:

import torch
from torch.func import grad

def my_loss_func(y, y_pred):
    loss_per_sample = (0.5 * y_pred - y) ** 2
    loss = loss_per_sample.mean()
    return loss, (y_pred, loss_per_sample)

fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
y_true = torch.rand(4)
y_preds = torch.rand(4, requires_grad=True)
out = fn(y_true, y_preds)
# out is ((grads w.r.t. y_true, grads w.r.t. y_preds), (y_pred, loss_per_sample))
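To make the nested return structure concrete, the following sketch unpacks it and compares the gradients with the closed-form derivatives of the loss. The variable names on the left-hand side of the unpacking are chosen here for illustration.

```python
import torch
from torch.func import grad

def my_loss_func(y, y_pred):
    loss_per_sample = (0.5 * y_pred - y) ** 2
    loss = loss_per_sample.mean()
    return loss, (y_pred, loss_per_sample)

y_true = torch.rand(4)
y_preds = torch.rand(4)

# Outer tuple: (gradients, aux). Inner tuples follow argnums and the aux layout.
(grad_y, grad_y_pred), (aux_pred, aux_loss) = grad(
    my_loss_func, argnums=(0, 1), has_aux=True
)(y_true, y_preds)

# Closed-form gradients of mean((0.5 * y_pred - y) ** 2) over n samples.
residual = 0.5 * y_preds - y_true
n = y_true.numel()
assert torch.allclose(grad_y, -2.0 * residual / n)
assert torch.allclose(grad_y_pred, residual / n)

# The auxiliary outputs are returned alongside the gradients, not differentiated.
assert torch.allclose(aux_pred, y_preds)
assert torch.allclose(aux_loss, residual ** 2)
```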

Note

Using PyTorch's torch.no_grad together with grad:

Case 1: Using torch.no_grad inside a function:

def f(x):
    with torch.no_grad():
        c = x ** 2
    return x - c

In this case, grad(f)(x) will respect the inner torch.no_grad: c is treated as a constant when differentiating.
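A small check of Case 1, written as a self-contained sketch: because c is computed under torch.no_grad, it contributes nothing to the gradient, so the derivative of x - c should come out as 1 rather than 1 - 2x.

```python
import torch
from torch.func import grad

def f(x):
    with torch.no_grad():
        c = x ** 2  # computed without gradient tracking, so treated as a constant
    return x - c

x = torch.randn([])
g = grad(f)(x)

# With c held constant, d/dx (x - c) is 1.
assert torch.allclose(g, torch.ones_like(g))
```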

Case 2: Using grad inside torch.no_grad context manager:

with torch.no_grad():
    grad(f)(x)

In this case, grad will respect the inner torch.no_grad, but not the outer one. This is because grad is a “function transform”: its result should not depend on a context manager outside of f.
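Case 2 can be checked the same way. The sketch below assumes the f from Case 1 and also uses torch.sin as a second function with a non-constant derivative; it compares the gradient computed inside an outer torch.no_grad block with the one computed outside it.

```python
import torch
from torch.func import grad

def f(x):
    with torch.no_grad():
        c = x ** 2
    return x - c

x = torch.randn([])

# Gradients computed with no surrounding context manager.
expected_f = grad(f)(x)
expected_sin = grad(torch.sin)(x)

# The same calls wrapped in an outer torch.no_grad.
with torch.no_grad():
    inside_f = grad(f)(x)
    inside_sin = grad(torch.sin)(x)

# The outer context manager does not change what grad returns.
assert torch.allclose(inside_f, expected_f)
assert torch.allclose(inside_sin, expected_sin)
assert torch.allclose(inside_sin, x.cos())
```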