torch.autograd.function.FunctionCtx.set_materialize_grads — PyTorch 2.7 documentation (original) (raw)
FunctionCtx.set_materialize_grads(value)[source][source]¶
Set whether to materialize grad tensors. Default is True
.
This should be called only from either the setup_context()
orforward()
methods.
If True
, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward()
and jvp()
methods.
Example::
class SimpleFunc(Function): @staticmethod def forward(ctx, x): return x.clone(), x.clone()
@staticmethod @once_differentiable def backward(ctx, g1, g2): return g1 + g2 # No check for None necessary
We modify SimpleFunc to handle non-materialized grad outputs
class Func(Function): @staticmethod def forward(ctx, x): ctx.set_materialize_grads(False) ctx.save_for_backward(x) return x.clone(), x.clone()
@staticmethod @once_differentiable def backward(ctx, g1, g2): x, = ctx.saved_tensors grad_input = torch.zeros_like(x) if g1 is not None: # We must check for None now grad_input += g1 if g2 is not None: grad_input += g2 return grad_input
a = torch.tensor(1., requires_grad=True) b, _ = Func.apply(a) # induces g2 to be undefined