Automatic mixed precision for Pytorch

🚀 Feature

We would like Pytorch to support the automatic mixed precision training recipe: auto-casting of Cuda operations to FP16 or FP32 based on a whitelist-blacklist model of what precision is best for each operation, as well as gradient scaling to avoid underflow during the backward pass.

Work in progress

Motivation

Mixed precision is essential to achieve good performance on Tensor Core GPUs (Volta + Turing). We've been trying to educate and evangelize for it since Volta's release. We have Apex, our own repository of tools for mixed precision training, which has seen moderate adoption. However, forcing users to install a separate toolkit is burdensome, especially if their environments aren't set up to build extensions, and we'll never be able to achieve the same performance as a native integration. Native, well documented support in Pytorch core is certainly the most convenient way to enable mixed precision training for the largest number of users.

Pitch

After initial discussions with @mruberry, @zdevito, @jamesr66a, @gchanan, and @jjsjann123 we believe the UX should permit auto-casting and gradient scaling as modular and independent components.

Auto-casting

Background

The current philosophy of Apex's Automatic Mixed Precision (Amp, in recommended mode "O1") is that when training with mixed precision, the user never needs to manually alter the precision of their model or data. The user declares their model in default (FP32) precision. The model parameters (leaves) are and remain FP32 for the duration of training. These leaves are also directly owned and stepped by the optimizer, which is identical to the ordinary behavior of Pytorch in the absence of mixed precision.

To ensure that FP16 is used for operations that benefit from it, Tensor.* methods and torch.* and torch.nn.functional.* functions are patched to cast data to a certain type before running the actual op. Which type depends on what precision is best for that op. For example, torch.mm is patched to cast both of its input matrices to FP16, which enables Tensor Cores. torch.log is patched to cast its input to FP32, because log's forward and backward may require a large dynamic range. This casting-as-data-flows-through-functions is the strategy used by Amp in Apex today, and it achieves accuracy comparable to pure FP32 training on a wide range of networks. It is also the strategy used by MXNet's and TensorFlow's Amp integrations. However, Apex's Amp is implemented by Python-side monkey-patching of torch.* and torch.nn.functional.* functions, which is invasive and not ideal for performance.
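A minimal pure-Python sketch of the casting-as-data-flows-through-functions policy. It assumes a toy FakeTensor that only tracks its dtype, and hypothetical WHITELIST/BLACKLIST sets; the real lists would be much longer and would live in the dispatch code rather than in Python.

```python
from dataclasses import dataclass, replace

# Hypothetical op lists; the real whitelist/blacklist would be much longer
# and would live in the C++ dispatch code rather than in Python.
WHITELIST = {"mm", "conv2d"}   # ops that benefit from FP16 (Tensor Cores)
BLACKLIST = {"log", "exp"}     # ops that need FP32's dynamic range


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a CUDA tensor; only its dtype matters here."""
    dtype: str  # "float16" or "float32"


def cast_inputs(op_name, *inputs):
    """Return the op's inputs cast according to its precision policy."""
    if op_name in WHITELIST:
        target = "float16"
    elif op_name in BLACKLIST:
        target = "float32"
    else:
        return inputs  # ops on neither list keep whatever type they receive
    return tuple(replace(t, dtype=target) for t in inputs)


x, w = FakeTensor("float32"), FakeTensor("float32")
print(cast_inputs("mm", x, w))                    # both inputs cast to float16
print(cast_inputs("log", FakeTensor("float16")))  # input cast to float32
```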

Proposed Implementation

For eager execution, we propose to integrate the same casting-as-data-flows-through-functions strategy that Apex's Amp uses. We propose to implement this by inserting the casts in some of the autogenerated C++ functions on the Cuda dispatch path. The autogeneration script will give each such function a few additional lines. For example, a whitelist function with one argument would be given something like

if (autocasting_is_enabled()) input = input.half();

These casts should be autograd-exposed, so they will be reversed in backward(). They should also precede the autograd-exposed call of the whitelist or blacklist op itself, so if the op saves its inputs for backward, the inputs are saved as the correct (post-cast) type.
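The ordering requirement can be sketched in plain Python. Everything here is hypothetical: `_autocast_enabled` stands in for the backend state queried by `autocasting_is_enabled()`, `half()` only relabels a toy value rather than performing an autograd-exposed cast, and `mm_forward` models an op that saves its inputs for backward. The point is only that the cast runs before the op, so the op saves the post-cast inputs.

```python
_autocast_enabled = True  # hypothetical stand-in for the backend autocast flag


def autocasting_is_enabled():
    return _autocast_enabled


class Value:
    """Toy tensor that records only its dtype."""

    def __init__(self, dtype="float32"):
        self.dtype = dtype

    def half(self):
        # In the proposal this would be an autograd-exposed cast.
        return Value("float16")


saved_for_backward = []


def mm_forward(a, b):
    # Models a whitelist op that saves its inputs for use in backward().
    saved_for_backward.extend([a, b])
    return Value(a.dtype)


def mm_with_autocast(a, b):
    # What the autogeneration script would prepend for a whitelist op:
    # the casts run *before* the op, so the op sees and saves FP16 inputs.
    if autocasting_is_enabled():
        a = a.half()
        b = b.half()
    return mm_forward(a, b)


out = mm_with_autocast(Value(), Value())
print(out.dtype, [t.dtype for t in saved_for_backward])
```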

On the Python side, the user shall request auto-casting by running the forward pass (or any portions of the forward pass where auto-casting is desired) under a nestable context manager, for example

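import contextlib
import torch

@contextlib.contextmanager
def autocast(whitelist_type=torch.float16, enabled=True):
    # torch.get_autocasting_state/torch.set_autocasting_state are the proposed backend hooks.
    old_whitelist_type, old_status = torch.get_autocasting_state()
    torch.set_autocasting_state(whitelist_type, enabled)
    try:
        yield
    finally:
        # Restore the enclosing state, which is what makes the context manager nestable.
        torch.set_autocasting_state(old_whitelist_type, old_status)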

torch.get_autocasting_state and torch.set_autocasting_state will get and set a backend state that is queryable from within the C++ dispatch functions.

whitelist_type can be changed to request that whitelist functions be autocast to types other than FP16. The enabled argument can be used to locally disable autocasting if the user wishes to have manual control over the types that are used in some regions of their model, while permitting autocasting in others.
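Since torch.get_autocasting_state and torch.set_autocasting_state are proposals rather than existing functions, here is a self-contained sketch that stores the state in a threading.local (an assumption about where the state lives) and uses strings in place of torch dtypes. It shows that nested autocast blocks restore the enclosing state on exit, including when an exception is raised.

```python
import contextlib
import threading

_state = threading.local()  # assumption: per-thread autocast state


def get_autocasting_state():
    # Defaults: whitelist ops would cast to "float16"; autocasting is off.
    return (getattr(_state, "whitelist_type", "float16"),
            getattr(_state, "enabled", False))


def set_autocasting_state(whitelist_type, enabled):
    _state.whitelist_type = whitelist_type
    _state.enabled = enabled


@contextlib.contextmanager
def autocast(whitelist_type="float16", enabled=True):
    old_whitelist_type, old_enabled = get_autocasting_state()
    set_autocasting_state(whitelist_type, enabled)
    try:
        yield
    finally:
        # Always restore the enclosing region's state.
        set_autocasting_state(old_whitelist_type, old_enabled)


with autocast():
    outer = get_autocasting_state()
    with autocast(enabled=False):      # locally hand control back to the user
        inner = get_autocasting_state()
    after_inner = get_autocasting_state()
print(outer, inner, after_inner, get_autocasting_state())
```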

My initial thought is that with autocast() may wrap any region of code where a graph is being constructed via explicit Python invocations of Pytorch ops (i.e., the forward pass), but should not wrap any region where backward is run through a previously constructed graph. All desired casting will have been recorded during graph construction, and a bare backward call will reverse it correctly without needing to be under an autocast context. Backward passes with create_graph=True also belong to the latter category (i.e., they should not be under a with autocast() context). The Gradient Penalty example under End to End Examples below shows this philosophy more clearly.

It's possible that running both forward and backward under the context manager won't do any harm (i.e., any casts requested during backward will be no-ops, because autograd is already faithfully reversing the type flow). If so, it can be permissible, but not required, to also run the backward pass under the context manager.

When training with FP16, gradient scaling and auto-casting must both be requested by the user. We envision that when autocasting to formats other than FP16, gradient scaling will not be necessary, and the auto-casting context manager can be used without any gradient scaling.

Example Usage

with autocast():
    output = model(input)
    loss = loss_fn(output, target)

The backward pass should be invoked outside the context manager. All casting has been appropriately recorded as part of the forward pass.

Within model.forward, if the user has regions where they wish explicit control over the precision, they may nest an invocation of with autocast(enabled=False):

def forward(self, x):
    x = self.layer_permitting_autocasting(x)
    with autocast(enabled=False):
        x = x.float()
        x = self.explicitly_float_layer(x)
    x = self.another_layer_permitting_autocasting(x)
    return x

Gotchas/Challenges

The Amp casts need to be recorded by autograd, so they'll be properly reversed in backward. Unfortunately, for many ops, dispatch into an autograd-disabled region currently occurs in VariableType*.cpp, at a higher level of the call chain than the Cuda-specific dispatch functions. VariableType*.cpp is also the level at which necessary data is saved for backward. In other words, by the time we've reached the Cuda-specific dispatch functions, it's too late to invoke autograd-exposed casts. The alternative is to insert the

if (autocasting_is_enabled()) input = input.half();  // or input.float(), depending on the op

snippets at the level of VariableType*.cpp, before we dive into autograd-disabled regions, but then these if statements will be on the hot path taken by all dispatches (Cuda and non-Cuda). I'd like to avoid having the if statement on any non-Cuda code path. This is a tricky problem and I need to think hard about it.

Gradient scaling

Background

Late in training, FP16 gradients can underflow, halting convergence and in some cases causing destabilization. Apex's Amp mitigates underflow via "dynamic gradient scaling." The implementation creates scaled gradients by handing the user scaled_loss = loss*scale_factor, then requiring that the user invoke scaled_loss.backward(). By the chain rule, all gradients flowing backward through the network are then scaled by scale_factor. Apex's Amp attempts to maximize use of FP16's full dynamic range by choosing the highest scale_factor that can be used without incurring inf/nan gradients, which is accomplished as follows: Initially, a high scale_factor is chosen. Each iteration, after backward() returns, gradients are checked for infs/nans. If any infs/nans are found, the optimizer skips the step, and the scale factor is reduced. The scale factor is also periodically increased if a streak of successful (inf/nan free) iterations occurs. Gradients are unscaled in FP32 before being applied to FP32 model parameters.
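The underflow problem, and the way scaling avoids it, can be checked without a GPU: Python's struct module can round-trip a float through IEEE half precision (the 'e' format). The gradient value and scale factor below are arbitrary illustrations, not values Amp would necessarily choose.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8             # well below FP16's smallest positive subnormal (about 6e-8)
scale_factor = 2.0**16  # illustrative scale factor

unscaled_fp16 = to_fp16(grad)                # underflows to 0.0
scaled_fp16 = to_fp16(grad * scale_factor)   # lands in FP16's normal range
recovered = scaled_fp16 / scale_factor       # unscale in higher precision (Python floats are FP64)

print(unscaled_fp16, scaled_fp16, recovered)
```

The recovered value differs from the original only by FP16 rounding of the scaled value, whereas the unscaled gradient is lost entirely.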

Proposed API

User scripts will implement gradient scaling with an instance of a helper class. Typical use would look like

scaler = torch.cuda.amp.AmpScaler()

for input, target in data:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Creating scaled gradients with torch.autograd.backward would look like

torch.autograd.backward(scaler.scale((output0, output1)), grad_tensors=(grad0, grad1))

Creating scaled gradients with torch.autograd.grad would look like

torch.autograd.grad(scaler.scale((output0, output1)), model.parameters(), grad_outputs=(grad0, grad1))

The best explanation for AmpScaler is the class itself, as found in the gradient scaling PR. It's lightweight and each function is documented.

scaler internally initializes and maintains the scale value, and updates it each iteration based on whether the optimizer's gradients contained infs or NaNs. If inf/NaN gradients are encountered in a given iteration, scaler.update reduces the scale. If no inf/NaN gradients are encountered, scaler.update increases the scale slightly. This approach achieves what dynamic gradient scaling intends: over time, riding the edge of the highest gradient scale that can be used without incurring overflow.
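The scale bookkeeping can be sketched as a small pure-Python class. The class name ToyScaler and the parameters init_scale, growth_factor, backoff_factor and growth_interval are assumptions chosen for illustration; the proposal does not fix them here, and the real AmpScaler works on CUDA tensors. For brevity, found_inf is passed to update() explicitly, whereas the proposed scaler.update() learns it from the preceding scaler.step calls.

```python
class ToyScaler:
    """Pure-Python sketch of dynamic loss-scale bookkeeping (not AmpScaler itself)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self._scale = init_scale
        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._growth_tracker = 0  # consecutive inf/NaN-free iterations

    def get_scale(self):
        return self._scale

    def scale(self, outputs):
        # Accept a single value or a (possibly nested) tuple/list of values,
        # mirroring scaler.scale((output0, output1)) in the proposal.
        if isinstance(outputs, (tuple, list)):
            return type(outputs)(self.scale(o) for o in outputs)
        return outputs * self._scale

    def update(self, found_inf, new_scale=None):
        if new_scale is not None:      # manual override by the user
            self._scale = float(new_scale)
            self._growth_tracker = 0
        elif found_inf:                # back off and restart the streak
            self._scale *= self._backoff_factor
            self._growth_tracker = 0
        else:                          # grow only after a streak of clean steps
            self._growth_tracker += 1
            if self._growth_tracker == self._growth_interval:
                self._scale *= self._growth_factor
                self._growth_tracker = 0


scaler = ToyScaler(init_scale=1024.0, growth_interval=3)
scaler.update(found_inf=True)        # inf/NaN seen: scale is halved
for _ in range(3):
    scaler.update(found_inf=False)   # third clean step in a row: scale doubles
print(scaler.get_scale(), scaler.scale((1.0, 2.0)))
```

The growth interval is why the scale only increases "slightly" on average: one backoff costs many clean iterations to undo.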

The user also has the freedom to manually reset the scale value at the end of any iteration, by passing a new scale to scaler.update as a Tensor or Python float.

AmpScaler instances contain at most a few Tensors and Python scalars. When checkpointing, the user can easily save an AmpScaler instance as part of the checkpoint, alongside model.state_dict() and optimizer.state_dict(). The model and optimizer instance(s) are not affected by AmpScaler, and remain exactly what they would be in the absence of mixed precision. Therefore, behaviors and invocations of model.state_dict() and optimizer.state_dict() themselves may remain unaltered.
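Because the scaler holds only a few scalars, checkpointing it can be a plain dict round-trip. This sketch assumes hypothetical state_dict/load_state_dict methods on a minimal stand-in class (named by analogy with model.state_dict() and optimizer.state_dict(); the proposal does not specify them) and uses json so it runs anywhere; a real script would more likely torch.save the dict alongside the model and optimizer state.

```python
import json


class ScalerState:
    """Minimal stand-in holding the values a scaler would need to checkpoint."""

    def __init__(self, scale=2.0**16, growth_tracker=0):
        self.scale = scale
        self.growth_tracker = growth_tracker

    def state_dict(self):
        return {"scale": self.scale, "growth_tracker": self.growth_tracker}

    def load_state_dict(self, state):
        self.scale = state["scale"]
        self.growth_tracker = state["growth_tracker"]


scaler = ScalerState(scale=4096.0, growth_tracker=17)
# Saved next to model.state_dict() and optimizer.state_dict() in a real checkpoint.
checkpoint = json.dumps({"scaler": scaler.state_dict()})

restored = ScalerState()
restored.load_state_dict(json.loads(checkpoint)["scaler"])
print(restored.scale, restored.growth_tracker)
```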

Interaction with Existing Optimizers

As long as training scripts adhere to the proposed API, existing optimizers (whether native or custom) will immediately work with mixed precision. In particular, the gradient scaling API does not rely on changes to the Python source of step() in any existing optimizers.

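scaler.step(optimizer) wraps optimizer.step() with logic to make it scaling-safe. Specifically, scaler.step(optimizer) unscales the gradients owned by optimizer (unless they were already unscaled by an explicit scaler.unscale(optimizer) call this iteration), checks them for infs/NaNs, and invokes optimizer.step() only if none are found. If infs/NaNs are found, the step is skipped so the parameters are not corrupted.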

optimizer.step() itself is untouched, and again, does not need to change for existing optimizers.
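A pure-Python sketch of that wrapping: unscale the optimizer's gradients unless they were already unscaled, check them for infs/NaNs, and call the untouched optimizer.step() only when they are finite. ToyOptimizer (plain SGD over lists of floats) and StepWrapper are hypothetical stand-ins, not the classes from the PR.

```python
import math


class ToyOptimizer:
    """Plain SGD over Python floats; its step() knows nothing about scaling."""

    def __init__(self, params, lr=0.1):
        self.params = params
        self.grads = [0.0] * len(params)
        self.lr = lr

    def step(self):
        self.params = [p - self.lr * g for p, g in zip(self.params, self.grads)]


class StepWrapper:
    """Sketch of scaler.unscale(optimizer) / scaler.step(optimizer)."""

    def __init__(self, scale):
        self.scale = scale
        # id(optimizer) -> found_inf for this iteration
        # (a full implementation would clear this in update()).
        self._unscaled = {}

    def unscale(self, optimizer):
        inv_scale = 1.0 / self.scale
        optimizer.grads = [g * inv_scale for g in optimizer.grads]
        found_inf = not all(math.isfinite(g) for g in optimizer.grads)
        self._unscaled[id(optimizer)] = found_inf
        return found_inf

    def step(self, optimizer):
        if id(optimizer) in self._unscaled:   # user already unscaled explicitly
            found_inf = self._unscaled[id(optimizer)]
        else:
            found_inf = self.unscale(optimizer)
        if not found_inf:
            optimizer.step()                  # untouched optimizer logic
        return found_inf                      # True means the step was skipped


opt = ToyOptimizer([1.0, 2.0])
opt.grads = [1024.0, -2048.0]   # gradients produced with a scale of 1024
wrapper = StepWrapper(scale=1024.0)
skipped = wrapper.step(opt)
print(skipped, opt.params)
```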

Interaction with New Custom Optimizers

AmpScaler defines a contract such that custom optimizers may implement their own scaling-safe step methods if they choose to. If an optimizer obeys this contract, AmpScaler.step will call the optimizer's step method directly. This gives custom optimizer authors a control point to implement ninja optimizations like sync-free dynamic gradient scaling.
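The proposal leaves the exact contract to the PR. Purely as an illustration, suppose an optimizer opts in by setting a hypothetical class attribute handles_amp_scaling = True and accepting the scaler in its step method. A scaler could then dispatch as sketched below; every name here is an assumption.

```python
class PlainOptimizer:
    def __init__(self):
        self.calls = []

    def step(self):
        self.calls.append("plain step")


class ScalingAwareOptimizer:
    # Hypothetical opt-in flag; the real contract is defined by the PR.
    handles_amp_scaling = True

    def __init__(self):
        self.calls = []

    def step(self, scaler=None):
        # A custom optimizer could fuse unscaling and the inf/NaN check into
        # its own kernels here (e.g. to avoid a host-device sync).
        self.calls.append(f"custom step with scale {scaler.scale}")


class DispatchingScaler:
    def __init__(self, scale):
        self.scale = scale

    def step(self, optimizer):
        if getattr(optimizer, "handles_amp_scaling", False):
            return optimizer.step(scaler=self)  # optimizer takes full control
        # Otherwise the generic unscale/check/step wrapper would run; this
        # sketch elides it and calls the plain step directly.
        return optimizer.step()


scaler = DispatchingScaler(scale=512.0)
a, b = PlainOptimizer(), ScalingAwareOptimizer()
scaler.step(a)
scaler.step(b)
print(a.calls, b.calls)
```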

Gotchas/Challenges

When the user invokes a backward pass with a scale factor, all gradients produced by this backward pass (leaf gradients produced by loss.backward or torch.autograd.backward, or out-of-place gradients produced by torch.autograd.grad) will be scaled. Therefore, anything that manipulates the gradients between scaler.scale(loss).backward() and scaler.step(optimizer) will require proper awareness of the scale factor. Examples of operations that require scale-factor-aware treatment are gradient clipping and gradient penalty, both covered under End to End Examples below.

In our opinion, requiring the user be aware of the scale factor when making direct use of the gradients is not terribly burdensome. The user has explicitly requested gradient scaling; they should not be surprised when they end up with scaled gradients on their hands. The treatment of such cases is also not difficult from a code-change perspective, as long as it is clearly documented.

Closures are also a challenge. Do we need to support explicit closure use? Based on the scripts and issues I've seen, closure use is not terribly common, and LBFGS is the only native optimizer that requires a closure. However, I think I can torture the proposed API into supporting closures if it turns out to be in high demand.

End to End Examples (Auto-Casting + Gradient Scaling)

Typical Use (1 loss, 1 optimizer)

scaler = AmpScaler()
...
for input, target in data:
    optimizer.zero_grad()
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Gradient Clipping

Gradient clipping requires awareness that the gradients resulting from scaler.scale(loss).backward() are scaled. One simple way to account for the scale factor is by clipping to max_norm*scaler.get_scale() instead of max_norm:

scaler = AmpScaler()
...
for input, target in data:
    optimizer.zero_grad()
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()

    # Gradients are scaled, so we clip to max_norm*scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm*scaler.get_scale())

    scaler.step(optimizer)
    scaler.update()

Here the scaled gradients are clipped. scaler.step(optimizer) is aware that gradients have not yet been unscaled, and unscales them under the hood before calling optimizer.step().
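Clipping scaled gradients to max_norm*scale gives the same result, after unscaling, as clipping unscaled gradients to max_norm, because the global gradient norm scales linearly with the scale factor. A quick pure-Python check, with a clip function that mimics the norm-based rescaling of torch.nn.utils.clip_grad_norm_ (minus details such as its small epsilon) and arbitrary example values:

```python
import math


def clip_by_norm(grads, max_norm):
    """Rescale grads so that their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    coef = max_norm / total_norm
    return [g * coef for g in grads]


grads = [3.0, 4.0]     # unscaled gradients with norm 5
max_norm = 1.0
scale = 1024.0

direct = clip_by_norm(grads, max_norm)                # clip the true gradients
scaled = [g * scale for g in grads]
via_scaled = [g / scale for g in clip_by_norm(scaled, max_norm * scale)]
print(direct, via_scaled)
```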

Gradient Clipping with Explicit Unscaling

The specific case of clipping scaled gradients isn't so hard (all you have to do is clip to max_norm*scaler.get_scale()). However, in general, between the backward pass and the optimizer step you may wish to manipulate gradients in some way that's not so easy to translate to scaled gradients. In such cases, you can unscale and step separately. Here's how that looks, using gradient clipping as an example once more:

scaler = AmpScaler()
...
for input, target in data:
    optimizer.zero_grad()
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()

    scaler.unscale(optimizer)
    # Since the optimizer's owned gradients are unscaled, we can clip to max_norm directly:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    scaler.step(optimizer)
    scaler.update()

Gradient Penalty

(based on Higher order gradients from the release notes)

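Gradient penalty also requires awareness that the gradients are scaled in certain places. Additionally, gradient penalty demonstrates out-of-place gradient creation with torch.autograd.grad, double backward (create_graph=True), and what counts as a "forward pass" for the purpose of autocasting.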

The following shows an implementation of gradient penalty under the proposed API.

scaler = AmpScaler()
...
for input, target in data:
    optimizer.zero_grad()
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # We should scale outputs for the out-of-place backward pass
    grad_params = torch.autograd.grad(scaler.scale(loss), model.parameters(), create_graph=True)

    # In general, the penalty term may depend nonlinearly on the out-of-place gradients, so to be safe,
    # manually unscale them before computing the penalty.  This unscale should be autograd-exposed.
    grad_params = [p*(1./scaler.get_scale()) for p in grad_params]

    # Compute the penalty term and add it to the loss.
    # The penalty term computation is effectively another snippet of forward pass, so it makes
    # sense to enable autocasting for this section as well:
    with autocast():
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        loss = loss + grad_norm

    # The usual scaling for backward will now accumulate leaf gradients that are appropriately scaled.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Gradient penalty is a tricky case to think about but writing the code is simple once the right pattern is established. Compared to the example in the release notes, the only extra line for gradient penalty computation is the unscaling grad_params = [p*(1./scaler.get_scale()) for p in grad_params]. I think this can be considered a documentation problem, and addressed by providing clear examples.
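Why the manual unscale matters: for a penalty that is not homogeneous in the gradients, computing it on scaled gradients and dividing by the scale afterwards does not recover the right value. A pure-Python check, using the (||g|| - 1)^2 form of penalty as an assumed example (the plain norm used in the code above happens to be linear in the scale):

```python
import math


def grad_norm(grads):
    return math.sqrt(sum(g * g for g in grads))


def penalty(grads):
    # Example nonlinear penalty: (||g|| - 1)^2
    return (grad_norm(grads) - 1.0) ** 2


grads = [0.6, 0.8]   # true (unscaled) gradients, norm 1.0
scale = 1024.0
scaled = [g * scale for g in grads]

correct = penalty([g * (1.0 / scale) for g in scaled])  # unscale first, as in the example
naive = penalty(scaled) / scale                         # penalize scaled grads, then divide
print(correct, naive)
```

The unscale-first value is essentially zero, as it should be for unit-norm gradients, while the naive value is roughly 1022.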

Multiple Models/Optimizers/Losses

Networks must use the same AmpScaler instance (and therefore the same scale) to create gradients for all backward passes in a given iteration, otherwise we open the door to nasty corner cases. For example, if two different losses, with different gradient scales, accumulate into the same parameters' .grads, the accumulation math breaks. If two different losses, with different gradient scales, accumulate into different parameters owned by the same optimizer, then when you invoke scaler.unscale(optimizer), there's no single correct value that can be used to unscale all the gradients owned by that optimizer, and handling multiple scale factors for different parameters within the same optimizer would get ugly fast. Requiring that networks use the same AmpScaler instance for all backward passes avoids all such control flow difficulties, while still achieving what loss scaling is meant to achieve.
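A numeric illustration of the accumulation problem, with arbitrary values: if two losses with different scales accumulate into the same .grad, no single divisor recovers the true sum, while a shared scale does.

```python
g0, g1 = 0.5, 0.25        # true gradient contributions from loss0 and loss1
s0, s1 = 1024.0, 256.0    # two different scales (the situation to avoid)

accumulated = s0 * g0 + s1 * g1      # what .grad would hold
true_sum = g0 + g1

candidates = {"divide by s0": accumulated / s0, "divide by s1": accumulated / s1}
same_scale = (s0 * g0 + s0 * g1) / s0   # one shared scale: exact recovery
print(true_sum, candidates, same_scale)
```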

scaler.update() must be called only at the end of the iteration, after scaler.step(optimizer) has been called for all optimizers used this iteration. This requirement allows update to account for infs/nans found among any of the optimizers' gradients.

scaler = torch.cuda.amp.AmpScaler()
...
for input, target in data:
    optimizer0.zero_grad()
    optimizer1.zero_grad()
    with autocast():
        output0 = model0(input)
        output1 = model1(input)
        loss0 = loss_fn(2 * output0 + 3 * output1, target)
        loss1 = loss_fn(3 * output0 - 5 * output1, target)

    scaler.scale(loss0).backward(retain_graph=True)
    scaler.scale(loss1).backward()

    # Users can choose which optimizers receive explicit unscaling
    scaler.unscale(optimizer0)

    scaler.step(optimizer0)
    scaler.step(optimizer1)
    scaler.update()

I had to write Apex's Amp to handle arbitrary combinations of multiple models/optimizers/losses. I'm painfully aware of the complicated combinations of models/optimizers/losses people want to implement. In my opinion, the proposed interface permits a great deal of flexibility in network design.

Gradient accumulation

Gradient accumulation across iterations (between steps) is a common use case. The proposed API accommodates gradient accumulation without trouble:

scaler = AmpScaler()
...
for i, (input, target) in enumerate(data):
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
        loss = loss/iters_to_accumulate
    scaler.scale(loss).backward()
    if (i + 1) % iters_to_accumulate == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
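The arithmetic behind this pattern, sketched in pure Python with made-up per-batch gradients: because update() only runs when the optimizer steps, the scale stays fixed across the accumulation window, so the accumulated .grad is the scale times the mean per-batch gradient, and a single unscale at step time recovers the mean.

```python
per_batch_grads = [0.2, 0.4, 0.9, 0.5]   # gradient of each micro-batch's loss
iters_to_accumulate = len(per_batch_grads)
scale = 4096.0                           # held fixed until the optimizer steps

grad_buffer = 0.0                        # plays the role of param.grad
for g in per_batch_grads:
    # loss / iters_to_accumulate, then scaled: contributes scale * g / N
    grad_buffer += scale * (g / iters_to_accumulate)

unscaled = grad_buffer / scale           # what scaler.step(optimizer) would see
mean_grad = sum(per_batch_grads) / iters_to_accumulate
print(unscaled, mean_grad)
```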

Switching automatic mixed precision on and off

If users want to run with or without autocasting+gradient scaling, they shouldn't have to litter their code with if statements. The API should allow one code path that accommodates easily switching autocasting+gradient scaling on and off.

The autocasting context manager and AmpScaler constructor provide such convenience by accepting an enabled=False argument.

In the following example, autocasting and gradient scaling can be switched on and off by flipping args.use_mixed_precision with no additional control flow required.

scaler = AmpScaler(enabled=args.use_mixed_precision)
...
for input, target in data:
    optimizer.zero_grad()
    with autocast(enabled=args.use_mixed_precision):
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
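One way the enabled flag can keep a single code path is for a disabled scaler to degrade into pass-throughs. A minimal sketch, where SwitchableScaler, its fixed scale, and the optimizer's grads attribute are hypothetical; the real AmpScaler behavior is defined by the PR.

```python
class SwitchableScaler:
    """Hypothetical scaler whose methods become pass-throughs when disabled."""

    def __init__(self, scale=2.0**16, enabled=True):
        self.enabled = enabled
        self._scale = scale

    def scale(self, loss):
        # Disabled: return the loss untouched, so backward() is ordinary.
        return loss * self._scale if self.enabled else loss

    def step(self, optimizer):
        if self.enabled:
            # Unscale before stepping (the inf/NaN check is omitted here).
            optimizer.grads = [g / self._scale for g in optimizer.grads]
        optimizer.step()

    def update(self):
        if not self.enabled:
            return  # nothing to adjust when scaling is off
        # A real scaler would grow or back off self._scale here.


for use_mixed_precision in (True, False):
    scaler = SwitchableScaler(scale=8.0, enabled=use_mixed_precision)
    print(use_mixed_precision, scaler.scale(3.0))
```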

Batch replay

Sometimes every iteration/batch is valuable enough that users don't want to skip any. Instead, it's preferable to replay the batch with a reduced loss scale until gradients do not contain infs/nans. Batch replay control flow is not provided by the API alone, but with the proposed gradient scaling PR, it would be easy to rig:

scaler = AmpScaler()
...
for input, target in data:
    # Replay the batch, updating the scale if necessary, until we receive gradients that aren't inf/nan.
    while True:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.unscale(optimizer)
        if scaler._found_inf(optimizer).item():
            scaler.update()
        else:
            break
    scaler.step(optimizer)
    scaler.update()
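The replay loop's effect on the scale can be simulated without a GPU by treating any scaled gradient that does not fit in FP16 as an inf (Python's struct raises OverflowError where a GPU cast would produce inf). The gradient values, the initial scale of 2^16, and the halving backoff are illustrative assumptions.

```python
import struct


def overflows_fp16(x):
    """True if x would become inf when stored as IEEE half precision."""
    try:
        struct.pack("<e", x)
        return False
    except OverflowError:
        return True


grads = [3.0, 0.5, 40.0]   # true gradients of one batch (fixed across replays)
scale = 2.0**16            # initial scale
attempts = 0

while True:                # replay the batch until its scaled grads fit in FP16
    attempts += 1
    found_inf = any(overflows_fp16(g * scale) for g in grads)
    if found_inf:
        scale *= 0.5       # what scaler.update() would do after an inf
    else:
        break

print(attempts, scale)
```

With these numbers, 40 * scale first fits under FP16's maximum of 65504 once the scale has been halved down to 1024, so the batch is replayed seven times in total.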

Alternatives

Python-side alternatives for gradient scaling and unscaling

The supplementary information contains an in-depth discussion of some alternatives I considered for the gradient scaling and gradient unscaling API.

Gradient scaling in the autograd backend

I recently submitted a PR that implemented gradient scaling directly in the autograd engine (Engine::execute).

Benefits:

Drawbacks:

cc @ezyang @gchanan @vincentqb