NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch

Most deep learning frameworks, including PyTorch, train using 32-bit floating point (FP32) arithmetic by default. However, using FP32 for all operations is not essential to achieve full accuracy for many state-of-the-art deep neural networks (DNNs). In 2017, NVIDIA researchers developed a methodology for mixed-precision training in which a few operations are executed in FP32 while the majority of the network is executed using 16-bit floating point (FP16) arithmetic. On Volta GPUs, FP16 arithmetic offers additional performance benefits: Tensor Cores provide much higher arithmetic throughput for FP16 matrix math, and FP16 data halves both the memory footprint and the memory bandwidth required relative to FP32.

With mixed-precision training, networks receive almost all the memory savings and improved throughput of pure FP16 training while matching the accuracy of FP32 training. A number of recently published results demonstrate the accuracy and high performance of this mixed-precision recipe.

A detailed description of mixed-precision theory can be found in this GTC talk, and a PyTorch-specific discussion can be found here. In brief, the methodology is:

  1. Weight updates are carried out in FP32 (see the sketch after this list).
  2. Loss scaling is used to prevent gradients from underflowing.
  3. A few operations (e.g. large reductions) are left in FP32.
  4. Everything else (the majority of the network) is executed in FP16.
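For concreteness, here is a minimal hand-rolled sketch of steps 1 and 2: maintaining FP32 master weights and applying a static loss scale. The toy model, variable names, data loader, and fixed scale of 128 are illustrative assumptions, not part of Apex; Amp automates and generalizes this pattern.

import torch

# Illustrative sketch of FP32 master weights + static loss scaling (not Apex code).
model = torch.nn.Linear(1024, 1024).cuda().half()                     # FP16 copy used for forward/backward
master_params = [p.detach().clone().float().requires_grad_() for p in model.parameters()]
optimizer = torch.optim.SGD(master_params, lr=1e-3)                   # optimizer updates the FP32 masters
loss_scale = 128.0                                                    # static scale, large enough to avoid FP16 underflow

for data, target in loader:                                           # 'loader' is assumed to yield FP16 CUDA tensors
    loss = torch.nn.functional.mse_loss(model(data), target)
    (loss * loss_scale).backward()                                    # scale the loss so small gradients survive in FP16
    for master, p in zip(master_params, model.parameters()):
        master.grad = p.grad.detach().float() / loss_scale            # unscale gradients into FP32
        p.grad = None
    optimizer.step()                                                  # FP32 weight update
    for master, p in zip(master_params, model.parameters()):
        p.data.copy_(master.data)                                     # refresh the FP16 weights from the masters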

Mixed-Precision in PyTorch

PyTorch has comprehensive built-in support for mixed-precision training. Calling .half() on a module converts its parameters to FP16, and calling .half() on a tensor converts its data to FP16. Any operations performed on such modules or tensors will be carried out using fast FP16 arithmetic. PyTorch also has strong built-in support for NVIDIA math libraries (cuBLAS and cuDNN). These libraries use Tensor Cores to perform GEMMs (e.g., fully connected layers) and convolutions on FP16 data. For a GEMM with dimensions [M, K] x [K, N] -> [M, N], cuBLAS can use Tensor Cores only if M, K, and N are multiples of 8.
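For example, a minimal sketch of converting a single layer and its input to FP16; the sizes below are chosen only so that M, K, and N are all multiples of 8:

import torch

# With M=64, K=1024, N=4096 (all multiples of 8), cuBLAS can execute this
# GEMM on Tensor Cores on Volta or later GPUs.
layer = torch.nn.Linear(in_features=1024, out_features=4096).cuda().half()
x = torch.randn(64, 1024, device="cuda").half()
y = layer(x)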

Introducing Apex

We developed Apex to streamline the mixed precision user experience and enable researchers to leverage mixed precision training in their models more conveniently. Apex is a lightweight PyTorch extension containing (among other utilities) Amp, short for Automatic Mixed-Precision. Amp enables users to take advantage of mixed precision training by adding just a few lines to their networks. Apex was released at CVPR 2018, and the current incarnation of Amp was announced at GTC San Jose 2019. Since release, Apex has seen good adoption by the PyTorch community, with nearly 3,000 stars on GitHub.

Amp emphasizes simplicity by performing relatively low-level modifications to the running model so you don’t need to worry about mixed types when writing or running your model training script. Models that use PyTorch in less common ways may find Amp’s assumptions don’t fit as well, but hooks exist to modify those assumptions as needed.

Drop-in Mixed-Precision Training: Amp

Amp provides all the benefits of mixed-precision training without any explicit management of loss scaling or type conversions.

Integrating Amp into an existing PyTorch model

The following steps are required to integrate Amp into an existing PyTorch script:

  1. Import Amp from the Apex library.
  2. Initialize Amp so it can insert the necessary modifications to the model, optimizer, and PyTorch internal functions.
  3. Mark where backpropagation (.backward()) occurs so that Amp can both scale the loss and clear per-iteration state.

Step one is a single line of code:

from apex import amp

Step two is also a single line of code; it requires that both the neural network model and the optimizer used for training already be defined:

model, optimizer = amp.initialize(model, optimizer)

You can pass additional options that give you finer control of how Amp adjusts the tensor and operation types.
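For instance, amp.initialize accepts an opt_level argument and per-property overrides such as loss_scale that control how aggressively Amp casts; the values shown below are just one common configuration, and the Amp documentation lists the full set of options.

# "O1" patches PyTorch functions to cast per the whitelist/blacklist described below;
# dynamic loss scaling adjusts the scale automatically when gradients overflow.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale="dynamic")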

As for step three, identify where in your code the backward pass occurs. You’ll see a few lines of code that look like the following:

loss = criterion(…)
loss.backward()
optimizer.step()

To enable loss scaling, you simply wrap the backward pass in the Amp context manager:

loss = criterion(…)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()

And that’s it. You can now re-run your script and have mixed-precision training enabled.
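Putting the three steps together, a bare-bones training loop looks roughly like the following; model, optimizer, criterion, and loader are assumed to be defined as usual, and amp.initialize is called after the model and optimizer are constructed but before the loop starts.

from apex import amp

model, optimizer = amp.initialize(model, optimizer)        # step 2: let Amp patch the model, optimizer, and functions

for data, target in loader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    with amp.scale_loss(loss, optimizer) as scaled_loss:   # step 3: Amp scales the loss...
        scaled_loss.backward()                             # ...and unscales gradients before the optimizer step
    optimizer.step()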

The Amp API offers additional features to handle complications like multiple optimizers, multiple backward passes, and working with custom C++ or CUDA layers not part of native PyTorch. Complete documentation can be found here.

How Amp works

At the logical level, Amp works by employing a whitelist / blacklist model. PyTorch’s tensor operations include neural network functions like torch.nn.functional.conv2d, basic math functions like torch.log, and tensor methods like torch.Tensor.__add__ (called when you write a + b for two tensors). Note that these functions are a level below the neural network Module API. Modules (e.g., torch.nn.Conv2d) call into the corresponding functions for their implementation.

We divide the universe of functions into three sets: a whitelist of operations that benefit from FP16 execution (such as GEMMs and convolutions), a blacklist of operations that need the precision or range of FP32 (such as large reductions and loss functions), and everything else.

In principle, the job of Amp is straightforward. Whenever a PyTorch function gets called, Amp checks whether it is on the whitelist, on the blacklist, or neither. If whitelisted, Amp casts all arguments to FP16; if blacklisted, it casts all arguments to FP32; and if neither, it simply ensures all arguments are of the same type (casting to the widest type present if not). In practice, though, implementing this policy is not entirely straightforward.
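In Python terms, the per-call policy can be sketched roughly as follows; the WHITELIST and BLACKLIST sets and the helper function are illustrative names, not Amp’s internal API:

import torch

WHITELIST = {torch.nn.functional.linear, torch.nn.functional.conv2d}    # illustrative entries only
BLACKLIST = {torch.nn.functional.softmax, torch.sum}                    # illustrative entries only

def cast_tensor_args(fn, tensor_args):
    # Illustrative sketch of the casting policy; not Amp's internal implementation.
    if fn in WHITELIST:                                         # FP16-friendly ops such as GEMMs and convolutions
        return [t.half() for t in tensor_args]
    if fn in BLACKLIST:                                         # ops that need FP32 precision or range
        return [t.float() for t in tensor_args]
    if any(t.dtype == torch.float32 for t in tensor_args):      # neither: promote everything to the widest type present
        return [t.float() for t in tensor_args]
    return list(tensor_args)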

Capturing function calls

Because PyTorch is so flexible and dynamic (a good thing!), it lacks a static model object or graph to latch onto and insert the casts described above. Instead, Amp does so dynamically by “monkey patching” the necessary functions to intercept and cast their arguments. For example, to ensure that torch.nn.functional.linear always casts its arguments to FP16, you can write code like this:

orig_linear = torch.nn.functional.linear

def wrapped_linear(*args):
    casted_args = []
    for arg in args:
        if torch.is_tensor(arg) and torch.is_floating_point(arg):
            casted_args.append(arg.to(torch.float16))   # cast floating-point tensor arguments to FP16
        else:
            casted_args.append(arg)
    return orig_linear(*casted_args)

torch.nn.functional.linear = wrapped_linear

Other subtleties exist to make the code more robust (different argument types, keyword arguments), but what Amp essentially does on a call to amp.initialize() is insert monkey patches on all of the relevant PyTorch functions so that arguments are cast appropriately at runtime.

Minimizing casts

One additional challenge with the function-based casting approach remains: applied naively, it causes any sort of parameter sharing to induce multiple casts of the same weight on each iteration. For example, the nn.RNNCell module will call an RNN function once for each timestep with the same FP32 weight arguments.

To ensure each weight is cast from FP32 to FP16 no more than once per iteration, Amp keeps an internal cache of parameter casts and reuses the FP16 copies when appropriate. The context manager around the backward pass tells Amp when to clear the cache at each iteration.
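As a rough sketch, the cache can be thought of as a dictionary from each FP32 parameter to its FP16 copy; the names and structure below are illustrative, not Amp’s internals:

_cast_cache = {}                            # illustrative: maps id(FP32 parameter) -> its FP16 copy

def cached_half(param):
    key = id(param)
    if key not in _cast_cache:
        _cast_cache[key] = param.half()     # cast at most once per iteration
    return _cast_cache[key]

def clear_cast_cache():
    _cast_cache.clear()                     # invoked when the scale_loss context exits each iteration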

Get Started with Apex

Installation instructions can be found on the Apex GitHub page, and complete API documentation can be found here. Apex was developed in dialogue with deep learning researchers at NVIDIA and the external community. It’s an open source project, and we welcome any suggestions, feature requests, bug reports, or contributions. Feel free to submit PRs and issues on GitHub, or leave a comment below.

Additional Resources

Our GTC 2019 presentation covers the theory and PyTorch usage of automatic mixed precision in-depth.

The NVIDIA PyTorch deep learning examples on NGC and GitHub, as well as the examples in the Apex repository, demonstrate automatic mixed-precision in full models.