torch.autograd.function.FunctionCtx.set_materialize_grads — PyTorch 2.7 documentation

FunctionCtx.set_materialize_grads(value)[source]

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context() or forward() methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward() and jvp() methods.
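As a minimal sketch of the setup_context() variant (the class name ScaleFunc and its outputs are illustrative, not from the original example), the flag can equally be set there when forward() is defined without a ctx argument:

import torch
from torch.autograd import Function

class ScaleFunc(Function):  # hypothetical name for illustration
    @staticmethod
    def forward(x):
        return x * 2, x * 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        # set_materialize_grads must be called here or in forward()
        ctx.set_materialize_grads(False)

    @staticmethod
    def backward(ctx, g1, g2):
        grad = None
        if g1 is not None:  # grads may now arrive as None
            grad = 2 * g1
        if g2 is not None:
            grad = 3 * g2 if grad is None else grad + 3 * g2
        return grad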

Example:

class SimpleFunc(Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone(), x.clone()

    @staticmethod
    @once_differentiable
    def backward(ctx, g1, g2):
        return g1 + g2  # No check for None necessary
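A brief usage note (this snippet is not in the original example): because materialization is on by default, the gradient for an unused output arrives as a tensor of zeros rather than None, so g1 + g2 is safe:

a = torch.tensor(1., requires_grad=True)
b, _ = SimpleFunc.apply(a)  # second output is never used
b.backward()                # g2 is materialized to zeros, so g1 + g2 works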

We modify SimpleFunc to handle non-materialized grad outputs:

class Func(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        ctx.save_for_backward(x)
        return x.clone(), x.clone()

    @staticmethod
    @once_differentiable
    def backward(ctx, g1, g2):
        x, = ctx.saved_tensors
        grad_input = torch.zeros_like(x)
        if g1 is not None:  # We must check for None now
            grad_input += g1
        if g2 is not None:
            grad_input += g2
        return grad_input

a = torch.tensor(1., requires_grad=True)
b, _ = Func.apply(a)  # induces g2 to be undefined
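To round out the example (the backward call below is a sketch, not part of the original snippet): since only the first output participates in the graph traversal, g2 is passed to backward() as None, which Func.backward now tolerates:

b.backward()   # only b is used, so g2 arrives as None
print(a.grad)  # expect tensor(1.), since grad_input accumulated only g1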