torch.autograd.function.FunctionCtx.mark_non_differentiable — PyTorch 2.7 documentation

FunctionCtx.mark_non_differentiable(*args)[source]

Mark outputs as non-differentiable.

This should be called at most once, in either setup_context() or forward(), and all arguments should be tensor outputs.
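For the setup_context() path, a minimal sketch might look like the following. It is not taken from this page; the class name MySort, the sort-based forward(), and the 1-D backward logic are illustrative assumptions that mirror the forward()-style example further below.

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class MySort(Function):
    @staticmethod
    def forward(x):
        # With a separate setup_context(), forward() does not take ctx.
        sorted_x, idx = x.sort()
        return sorted_x, idx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        sorted_x, idx = output
        # Mark the integer indices as non-differentiable here instead of in forward().
        ctx.mark_non_differentiable(idx)
        ctx.save_for_backward(x, idx)

    @staticmethod
    @once_differentiable
    def backward(ctx, g1, g2):  # g2 is a zero tensor for the marked output
        x, idx = ctx.saved_tensors
        grad_input = torch.zeros_like(x)
        grad_input.index_add_(0, idx, g1)  # assumes a 1-D input, as in the example below
        return grad_input

With this style, forward() does not receive ctx, so both mark_non_differentiable() and save_for_backward() are called from setup_context().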

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward(), but it is always going to be a zero tensor with the same shape as the corresponding output.

This is used, e.g., for indices returned from a sort. See example:

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class Func(Function):
    @staticmethod
    def forward(ctx, x):
        sorted, idx = x.sort()
        ctx.mark_non_differentiable(idx)
        ctx.save_for_backward(x, idx)
        return sorted, idx

    @staticmethod
    @once_differentiable
    def backward(ctx, g1, g2):  # still need to accept g2
        x, idx = ctx.saved_tensors
        grad_input = torch.zeros_like(x)
        grad_input.index_add_(0, idx, g1)
        return grad_input
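A minimal usage sketch, not from the documentation, assuming the Func class above and a small 1-D input: it shows that the marked output comes back with requires_grad=False while gradients still flow through the differentiable output.

import torch

x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
sorted_x, idx = Func.apply(x)

print(sorted_x.requires_grad)  # True  -- ordinary differentiable output
print(idx.requires_grad)       # False -- marked non-differentiable

# backward() still receives a gradient for idx (g2), but it is a zero tensor,
# so only g1, the gradient of the sorted values, contributes to x.grad.
sorted_x.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])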