TorchDynamo Update 6: Training support with AOTAutograd

Recap

We are working on an experimental project called TorchDynamo. TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph which is then just-in-time compiled with an ensemble of different backends and autotuning. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends to get the best of both worlds: usability and performance.
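For readers less familiar with TorchDynamo, here is a minimal sketch of the usage pattern, following the style of the TorchDynamo README (the `my_compiler` backend and `toy_example` function are illustrative names): the backend receives the captured FX graph plus example inputs and returns a callable, and TorchDynamo mixes the compiled pieces with regular Python execution around graph breaks.

```python
import torch
import torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Backends receive the captured FX graph and example inputs.
    # Here we just print the graph and fall back to eager execution.
    gm.graph.print_tabular()
    return gm.forward  # any callable works

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:   # data-dependent control flow causes a graph break;
        b = b * -1    # TorchDynamo handles it by mixing in Python execution
    return x * b

with torchdynamo.optimize(my_compiler):
    toy_example(torch.randn(10), torch.randn(10))
```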

If you are new here, the TorchDynamo README is a good place to start. You can also catch up on our prior posts:

Adding Training in TorchDynamo

The biggest change since last time has been adding training support with AOTAutograd.

Training adds challenges because the PyTorch Automatic Differentiation engine sits below the PyTorch dispatcher in C++. Therefore, the operators running in the backward pass are not directly visible to TorchDynamo at the Python level.

To support training with TorchDynamo, we need to run and optimize the operations that happen in the .backward() pass. There are a few different ways to support the backward pass:

AOTAutograd

AOTAutograd relies on the recently introduced torch_dispatch-based tracing mechanism to capture the backward graph ahead of time. Because it traces through the PyTorch core autograd engine, the backward graph it captures is generated by the same machinery used in eager mode. This is a benefit over the parallel implementation of symbolic autodiff in TorchScript, which is difficult to maintain over time.
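As a rough illustration, the sketch below uses functorch's `aot_function` with a backend that only prints the graphs it receives (`print_compiler` and `fn` are illustrative names, not the benchmark code): the compiler callback is invoked once with the captured forward graph and once with the backward graph.

```python
import torch
from functorch.compile import aot_function

def print_compiler(fx_module: torch.fx.GraphModule, example_inputs):
    # Called once for the forward graph and once for the backward graph.
    fx_module.graph.print_tabular()
    return fx_module  # GraphModules are callable, so we can run them unmodified

def fn(x, y):
    return (torch.sin(x) * y).sum()

aot_fn = aot_function(fn, fw_compiler=print_compiler, bw_compiler=print_compiler)

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
aot_fn(x, y).backward()  # running forward + backward triggers both compilations
```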

Additionally, the AOTAutograd design allows easier joint forward-backward graph optimizations such as activation checkpointing (please refer to this post for more details). The result is two separate graphs, one for the forward pass and one for the backward pass, which we then compile separately using backend compilers. In this post, AOTAutograd uses TorchScript with NNC/nvFuser to compile the generated forward and backward graphs.
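A sketch of that setup might look like the following, assuming the `ts_compile` and `min_cut_rematerialization_partitioner` utilities exposed by `functorch.compile` and a CUDA device for nvFuser; treat this as a sketch rather than the exact benchmark harness.

```python
import torch
from functorch.compile import (
    aot_function,
    ts_compile,
    min_cut_rematerialization_partitioner,
)

def fn(x, y):
    return (torch.sin(x) * torch.cos(y)).sum()

# ts_compile scripts each FX graph with TorchScript so NNC/nvFuser can fuse it.
# The partitioner decides which intermediates to save for backward vs. recompute.
compiled_fn = aot_function(
    fn,
    fw_compiler=ts_compile,
    bw_compiler=ts_compile,
    partition_fn=min_cut_rematerialization_partitioner,
)

x = torch.randn(1024, device="cuda", requires_grad=True)
y = torch.randn(1024, device="cuda", requires_grad=True)
compiled_fn(x, y).backward()
```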

Results

We tested this integration on TorchBench models on NVIDIA A100 GPUs. Our training measurement is forward() + loss calculation + backward(), where the loss calculation is just mean() as a placeholder. This is not full training because it does not include the optimizer. We also set the models to eval mode, which removes randomness from operations like Dropout and lets us run accuracy tests. This is done so that we can verify accuracy and maintain confidence in the correctness of our benchmarks, although it introduces some divergence from actual training: for example, batch norm becomes fusible as a pointwise operator, while dropout is no longer fusible.
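Concretely, the measured region per model is roughly the sketch below (`model` and `example_inputs` stand in for a TorchBench model and its inputs; the warmup and timing harness is omitted):

```python
import torch

def train_iteration(model, example_inputs):
    # Measured region: forward pass + placeholder mean() loss + backward pass.
    # There is no optimizer step, and the model stays in eval mode so that
    # Dropout/BatchNorm behave deterministically for accuracy comparisons.
    model.zero_grad(set_to_none=True)
    output = model(*example_inputs)
    loss = output.mean()
    loss.backward()
    return loss
```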

We check the numerical accuracy by comparing the computed gradients. We measure both the latency and the peak memory footprint of the training iteration. The table below shows speedup and memory savings of different configurations normalized to eager performance. For AOTAutograd, we use the min-cut recomputation algorithm as discussed in this post.
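The accuracy and memory checks can be sketched as follows (assuming `eager_model` and `compiled_model` hold identical weights, e.g. via `copy.deepcopy`, and that the compiled path still populates `.grad` on the parameters; `check_gradients` and `peak_memory` are illustrative helpers, not the exact benchmark code):

```python
import torch

def check_gradients(eager_model, compiled_model, example_inputs,
                    rtol=1e-3, atol=1e-3):
    # Run one iteration through each path and compare per-parameter gradients.
    for model in (eager_model, compiled_model):
        model.zero_grad(set_to_none=True)
        model(*example_inputs).mean().backward()
    return all(
        torch.allclose(p_e.grad, p_c.grad, rtol=rtol, atol=atol)
        for p_e, p_c in zip(eager_model.parameters(), compiled_model.parameters())
    )

def peak_memory(step_fn):
    # Peak GPU memory of one training iteration (CUDA assumed).
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()
```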

forward()+backward() improvement over eager mode on NVIDIA A100

Some of the observations are:

Outstanding Issues

There is still some work remaining to get the accuracy checks passing on all of the TorchBench models. In the table above, we see 4 models failing, and a few more models are skipped here.

The failures are spread across different components (TorchDynamo, AOTAutograd, TorchBench, TorchScript and nvFuser); the running list of these issues is here. Because so many components are involved, this integration exercise has revealed bugs and issues across all of them. Some of them are:

Next Steps

While there are still many outstanding issues, these results give us confidence that AOTAutograd+TorchDynamo can deliver speedups for training. Looking further ahead, there are larger and more complex topics such as supporting dynamic shapes and distributed training. These depend on ongoing efforts in PyTorch, but we are incredibly optimistic about the recent progress there and look forward to an exciting future!

This is joint work with @jansel and @Chillee.