Double Backward with Custom Functions


Created On: Aug 13, 2021 | Last Updated: Aug 13, 2021 | Last Verified: Nov 05, 2024

It is sometimes useful to run backward twice through the backward graph, for example to compute higher-order gradients. Supporting double backward, however, takes an understanding of autograd and some care. Functions that support performing backward a single time are not necessarily equipped to support double backward. In this tutorial we show how to write a custom autograd function that supports double backward, and point out some things to look out for.

When writing a custom autograd function to backward through twice, it is important to know when operations performed in a custom function are recorded by autograd, when they aren't, and, most importantly, how save_for_backward works with all of this.

Custom functions implicitly affect grad mode in two ways:

- During forward, autograd does not record the graph for any operations performed within the forward function. When forward completes, the backward function of the custom function becomes the grad_fn of each of forward's outputs.
- During backward, autograd records the computation graph used to compute the backward pass if create_graph is specified, as the sketch below illustrates.
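Here is a minimal sketch of the second point using a built-in op rather than a custom function: because create_graph=True is passed, the operations performed during the backward pass are themselves recorded, so the resulting gradient can be differentiated again.

import torch

x = torch.tensor(1., requires_grad=True)
y = x.exp()
# create_graph=True makes autograd record the backward computation itself
g, = torch.autograd.grad(y, x, create_graph=True)
print(g.grad_fn)                 # not None: g can be differentiated again
g2, = torch.autograd.grad(g, x)  # second-order gradient of exp(x), i.e. exp(x)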

Next, to understand how save_for_backward interacts with the above, we can explore a couple of examples:

Saving the Inputs

Consider this simple squaring function. It saves an input tensor for backward. Double backward works automatically when autograd is able to record operations in the backward pass, so there is usually nothing to worry about when we save an input for backward, since the input will have a grad_fn if it is a function of any tensor that requires grad. This allows the gradients to be properly propagated.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Because we are saving one of the inputs use save_for_backward
        # Save non-tensors and non-inputs/non-outputs directly on ctx
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_out):
        # A function supports double backward automatically if autograd
        # is able to record the computations performed in backward
        x, = ctx.saved_tensors
        return grad_out * 2 * x

Use double precision because the finite-differencing method magnifies errors

x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)

Use gradgradcheck to verify second-order derivatives

torch.autograd.gradgradcheck(Square.apply, x)

We can use torchviz to visualize the graph to see why this works:

import torchviz

x = torch.tensor(1., requires_grad=True).clone()
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

We can see that the gradient with respect to x is itself a function of x (dout/dx = 2x), and the graph of this function has been properly constructed.

https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png
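We can also confirm this numerically by differentiating grad_x once more; since dout/dx = 2x, the second-order derivative should be the constant 2 (a quick check, using the Square function defined above):

x = torch.tensor(3., requires_grad=True, dtype=torch.double)
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)  # dout/dx = 2x = 6
grad2_x, = torch.autograd.grad(grad_x, x)                 # d^2out/dx^2 = 2
print(grad_x.item(), grad2_x.item())  # 6.0 2.0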

Saving the Outputs

A slight variation on the previous example is to save an output instead of an input. The mechanics are similar because outputs are also associated with a grad_fn.
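For instance, the output of the Square function defined above already carries a grad_fn, so a saved output stays connected to the graph in the same way a saved input would (a small illustrative check):

x = torch.tensor(1., requires_grad=True)
out = Square.apply(x)
print(out.requires_grad)  # True
print(out.grad_fn)        # the backward node that autograd attached to the output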

class Exp(torch.autograd.Function):
    # Simple case where everything goes well
    @staticmethod
    def forward(ctx, x):
        # This time we save the output
        result = torch.exp(x)
        # Note that we should use save_for_backward here when
        # the tensor saved is an output (or an input).
        ctx.save_for_backward(result)
        return result

@staticmethod
def backward(ctx, grad_out):
    result, = ctx.saved_tensors
    return result * grad_out

x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()

Validate our gradients using gradcheck

torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)

Use torchviz to visualize the graph:

out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png
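As with the squaring example, we can check the second-order gradient numerically; since every derivative of exp(x) is exp(x) itself, differentiating twice should reproduce the output (a quick sketch):

x = torch.tensor(1., requires_grad=True, dtype=torch.double)
out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
grad2_x, = torch.autograd.grad(grad_x, x)
print(torch.allclose(grad2_x, out))  # True: d^2(e^x)/dx^2 = e^x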

When Backward is not Tracked

Finally, let's consider an example where it may not be possible for autograd to track gradients for a function's backward at all. We can imagine cube_backward to be a function that may require a non-PyTorch library like SciPy or NumPy, or is written as a C++ extension. The workaround demonstrated here is to create another custom function, CubeBackward, where you also manually specify the backward of cube_backward!

def cube_forward(x):
    return x**3

def cube_backward(grad_out, x):
    return grad_out * 3 * x**2

def cube_backward_backward(grad_out, sav_grad_out, x):
    return grad_out * sav_grad_out * 6 * x

def cube_backward_backward_grad_out(grad_out, x):
    return grad_out * 3 * x**2

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return cube_forward(x)

@staticmethod
def backward(ctx, grad_out):
    x, = ctx.saved_tensors
    return CubeBackward.apply(grad_out, x)

class CubeBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_out, x):
        ctx.save_for_backward(x, grad_out)
        return cube_backward(grad_out, x)

@staticmethod
def backward(ctx, grad_out):
    x, sav_grad_out = ctx.saved_tensors
    dx = cube_backward_backward(grad_out, sav_grad_out, x)
    dgrad_out = cube_backward_backward_grad_out(grad_out, x)
    return dgrad_out, dx

x = torch.tensor(2., requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)

Use torchviz to visualize the graph:

out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png
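Even though cube_backward itself is opaque to autograd, the nested custom functions let us take second-order gradients as usual; at x = 2 we expect dout/dx = 3*x**2 = 12 and d^2out/dx^2 = 6*x = 12 (a quick check):

x = torch.tensor(2., requires_grad=True, dtype=torch.double)
out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
grad2_x, = torch.autograd.grad(grad_x, x)
print(grad_x.item(), grad2_x.item())  # 12.0 12.0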

To conclude, whether double backward works for your custom function simply depends on whether the backward pass can be tracked by autograd. With the first two examples we show situations where double backward works out of the box. With the third example, we demonstrate a technique that enables a backward function to be tracked when it otherwise would not be.