apex.amp — Apex 0.1.0 documentation


This page documents the updated API for Amp (Automatic Mixed Precision), a tool to enable Tensor Core-accelerated training in only 3 lines of Python.

A runnable, comprehensive Imagenet example demonstrating good practices can be found on the GitHub page.

GANs are a tricky case that many people have requested. A comprehensive DCGAN example is under construction.

If you already implemented Amp based on the instructions below, but it isn’t behaving as expected, please review Advanced Amp Usage to see if any topics match your use case. If that doesn’t help, file an issue.

opt_levels and Properties

Amp allows users to easily experiment with different pure and mixed precision modes. Commonly-used default modes are chosen by selecting an “optimization level” or opt_level; each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed precision training. Finer-grained control of how a given opt_level behaves can be achieved by passing values for particular properties directly to amp.initialize. These manually specified values override the defaults established by the opt_level.

Example:

```python
# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Allow Amp to perform casts as required by the opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
...

# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
```

Users should not manually cast their model or data to .half(), regardless of what opt_level or properties are chosen. Amp intends that users start with an existing default (FP32) script, add the three lines corresponding to the Amp API, and begin training with mixed precision. Amp can also be disabled, in which case the original script will behave exactly as it used to. In this way, there’s no risk in adhering to the Amp API, and a lot of potential performance benefit.

Note

Because it’s never necessary to manually cast your model (aside from the call to amp.initialize) or input data, a script that adheres to the new API can switch between different opt-levels without having to make any other changes.

Properties

Currently, the under-the-hood properties that govern pure or mixed precision training are the following:

cast_model_type

patch_torch_functions

keep_batchnorm_fp32

master_weights

loss_scale

Again, you often don’t need to specify these properties by hand. Instead, select an opt_level, which will set them up for you. After selecting an opt_level, you can optionally pass property kwargs as manual overrides.

If you attempt to override a property that does not make sense for the selected opt_level, Amp will raise an error with an explanation. For example, selecting opt_level="O1" combined with the override master_weights=True does not make sense. O1 inserts casts around Torch functions rather than model weights. Data, activations, and weights are recast out-of-place on the fly as they flow through patched functions. Therefore, the model weights themselves can (and should) remain FP32, and there is no need to maintain separate FP32 master weights.
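The precedence of opt_level defaults over manual overrides can be pictured with a small pure-Python sketch. The default values below mirror the property tables in this document; the merge and validation logic is an illustrative simplification, not apex’s actual code:

```python
# Illustrative sketch of opt_level defaults plus manual overrides.
# The defaults mirror the property tables in this document; the
# merge/validation logic is a simplification, not apex's code.

OPT_LEVEL_DEFAULTS = {
    "O0": dict(cast_model_type="torch.float32", patch_torch_functions=False,
               keep_batchnorm_fp32=None, master_weights=False, loss_scale=1.0),
    "O1": dict(cast_model_type=None, patch_torch_functions=True,
               keep_batchnorm_fp32=None, master_weights=None, loss_scale="dynamic"),
    "O2": dict(cast_model_type="torch.float16", patch_torch_functions=False,
               keep_batchnorm_fp32=True, master_weights=True, loss_scale="dynamic"),
    "O3": dict(cast_model_type="torch.float16", patch_torch_functions=False,
               keep_batchnorm_fp32=False, master_weights=False, loss_scale=1.0),
}

def resolve_properties(opt_level, **overrides):
    """Start from the opt_level defaults, then apply non-None overrides."""
    props = dict(OPT_LEVEL_DEFAULTS[opt_level])
    for name, value in overrides.items():
        if value is None:
            # None means "use the opt_level default"
            continue
        if opt_level == "O1" and name == "master_weights":
            # O1 keeps model weights in FP32, so separate FP32 master
            # weights make no sense; real Amp raises an error here too.
            raise ValueError("master_weights is not applicable to O1")
        props[name] = value
    return props
```

For instance, `resolve_properties("O3", keep_batchnorm_fp32=True)` yields the O3 defaults with batchnorms kept in FP32, mimicking the override discussed under O3 below.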

opt_levels

Recognized opt_levels are "O0", "O1", "O2", and "O3".

O0 and O3 are not true mixed precision, but they are useful for establishing accuracy and speed baselines, respectively.

O1 and O2 are different implementations of mixed precision. Try both, and see what gives the best speedup and accuracy for your model.

O0: FP32 training

Your incoming model should be FP32 already, so this is likely a no-op. O0 can be useful to establish an accuracy baseline.

Default properties set by O0:

cast_model_type=torch.float32

patch_torch_functions=False

keep_batchnorm_fp32=None (effectively, “not applicable,” everything is FP32)

master_weights=False

loss_scale=1.0

O1: Mixed Precision (recommended for typical use)

Patch all Torch functions and Tensor methods to cast their inputs according to a whitelist-blacklist model. Whitelist ops (for example, Tensor Core-friendly ops like GEMMs and convolutions) are performed in FP16. Blacklist ops that benefit from FP32 precision (for example, softmax) are performed in FP32. O1 also uses dynamic loss scaling, unless overridden.

Default properties set by O1:

cast_model_type=None (not applicable)

patch_torch_functions=True

keep_batchnorm_fp32=None (again, not applicable, all model weights remain FP32)

master_weights=None (not applicable, model weights remain FP32)

loss_scale="dynamic"
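The out-of-place casting that O1 inserts around whitelisted and blacklisted functions can be sketched in plain Python with a toy dtype-tagged value. All names here (`ToyTensor`, `toy_matmul`, `toy_softmax`) are invented for illustration; real Amp patches the actual torch functions in place:

```python
import functools

# Toy stand-in for a tensor that only tracks its dtype. The point is the
# out-of-place cast wrapped around each function, as O1 does for
# whitelisted/blacklisted torch ops. Everything here is illustrative.

class ToyTensor:
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        # Out-of-place: the original tensor is untouched, so model
        # weights can remain FP32 while the op runs in FP16.
        return ToyTensor(dtype)

def casts_inputs(target_dtype):
    """Decorator that casts every input to target_dtype before the call."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*tensors):
            casted = [t.to(target_dtype) for t in tensors]
            return fn(*casted)
        return wrapper
    return decorator

@casts_inputs("float16")   # whitelist op: run in FP16 (e.g. a GEMM)
def toy_matmul(a, b):
    return ToyTensor(a.dtype)

@casts_inputs("float32")   # blacklist op: run in FP32 (e.g. softmax)
def toy_softmax(x):
    return ToyTensor(x.dtype)
```

Calling `toy_matmul` on two FP32 toy tensors produces an FP16 result while leaving the FP32 inputs unchanged, which is why O1 needs no separate master weights.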

O2: “Almost FP16” Mixed Precision

O2 casts the model weights to FP16, patches the model’s forward method to cast input data to FP16, keeps batchnorms in FP32, maintains FP32 master weights, updates the optimizer’s param_groups so that optimizer.step() acts directly on the FP32 weights (followed by FP32 master weight->FP16 model weight copies if necessary), and implements dynamic loss scaling (unless overridden). Unlike O1, O2 does not patch Torch functions or Tensor methods.

Default properties set by O2:

cast_model_type=torch.float16

patch_torch_functions=False

keep_batchnorm_fp32=True

master_weights=True

loss_scale="dynamic"
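The master-weight bookkeeping that O2 performs can be sketched in pure Python. Here “FP16” is faked by rounding to 3 decimal places so the effect of a low-precision weight copy is visible; this is a toy illustration of the scheme, not apex code:

```python
def to_half(x):
    # Toy stand-in for FP16 rounding: keep 3 decimal places so that
    # small updates are visibly lost in the low-precision copy.
    return round(x, 3)

class MasterWeightOptimizer:
    """Toy sketch of the O2 scheme: step on FP32 masters, then copy
    the masters back into the "FP16" model weights."""

    def __init__(self, model_weights, lr):
        self.model_weights = model_weights               # "FP16" copies
        self.master = [float(w) for w in model_weights]  # FP32 masters
        self.lr = lr

    def step(self, master_grads):
        # SGD update applied in full precision on the master weights
        for i, g in enumerate(master_grads):
            self.master[i] -= self.lr * g
        # FP32 master weight -> FP16 model weight copies
        for i, m in enumerate(self.master):
            self.model_weights[i] = to_half(m)
```

With lr=1e-4 and a gradient of 1.0, a single update moves the master from 1.0 to 0.9999, but the rounded model copy snaps back to 1.0: the update survives only in the FP32 master, which is exactly why master weights exist.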

O3: FP16 training

O3 may not achieve the stability of the true mixed precision options O1 and O2. However, it can be useful to establish a speed baseline for your model, against which the performance of O1 and O2 can be compared. If your model uses batch normalization, to establish “speed of light” you can try O3 with the additional property override keep_batchnorm_fp32=True (which enables cudnn batchnorm).

Default properties set by O3:

cast_model_type=torch.float16

patch_torch_functions=False

keep_batchnorm_fp32=False

master_weights=False

loss_scale=1.0

Unified API

apex.amp.initialize(models, optimizers=None, enabled=True, opt_level='O1', cast_model_type=None, patch_torch_functions=None, keep_batchnorm_fp32=None, master_weights=None, loss_scale=None, cast_model_outputs=None, num_losses=1, verbosity=1, min_loss_scale=None, max_loss_scale=16777216.0)[source]

Initialize your models, optimizers, and the Torch tensor and functional namespace according to the chosen opt_level and overridden properties, if any.

amp.initialize should be called after you have finished constructing your model(s) and optimizer(s), but before you send your model through any DistributedDataParallel wrapper. See Distributed training in the Imagenet example.

Currently, amp.initialize should only be called once, although it can process an arbitrary number of models and optimizers (see the corresponding Advanced Amp Usage topic). If you think your use case requires amp.initialize to be called more than once, let us know.

Any property keyword argument that is not None will be interpreted as a manual override.

To prevent having to rewrite anything else in your script, name the returned models/optimizers to replace the passed models/optimizers, as in the code sample below.

Parameters

Returns

Model(s) and optimizer(s) modified according to the opt_level. If either the models or optimizers args were lists, the corresponding return value will also be a list.

Permissible invocations:

```python
model, optim = amp.initialize(model, optim, ...)
model, [optim1, optim2] = amp.initialize(model, [optim1, optim2], ...)
[model1, model2], optim = amp.initialize([model1, model2], optim, ...)
[model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2], ...)
```

The following is not an exhaustive list of the cross product of options that are possible, just a set of examples:

```python
model, optim = amp.initialize(model, optim, opt_level="O0")
model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")

model, optim = amp.initialize(model, optim, opt_level="O1")  # uses loss_scale="dynamic" default
model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")

model, optim = amp.initialize(model, optim, opt_level="O2")  # uses loss_scale="dynamic" default
model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")

model, optim = amp.initialize(model, optim, opt_level="O3")  # uses loss_scale=1.0 default
model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")
```

The Imagenet example demonstrates live use of various opt_levels and overrides.

apex.amp.scale_loss(loss, optimizers, loss_id=0, model=None, delay_unscale=False, delay_overflow_check=False)[source]

On context manager entrance, creates scaled_loss = (loss.float()) * current_loss_scale. scaled_loss is yielded so that the user can call scaled_loss.backward():

```python
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```

On context manager exit (if delay_unscale=False), the gradients are checked for infs/NaNs and unscaled, so that optimizer.step() can be called.
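The scale/unscale lifecycle described above can be sketched as a pure-Python context manager. This is a simplified illustration of the idea (scale the loss on entry, check for infs/NaNs and unscale on exit, adjust a dynamic scale); real Amp additionally handles master gradients, multiple losses, and per-loss scalers:

```python
import math
from contextlib import contextmanager

class ToyScaler:
    """Minimal dynamic loss scaler sketch: halve the scale on overflow,
    double it after a streak of clean steps. Illustrative only."""

    def __init__(self, init_scale=2.0 ** 15, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._clean_steps = 0
        self.step_skipped = False

    @contextmanager
    def scale_loss(self, loss, grads):
        # Entrance: yield the scaled loss for the user's backward().
        yield loss * self.scale
        # Exit: the grads produced by backward() are scaled; check them.
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            self.scale /= 2.0           # overflow: shrink scale, skip step
            self._clean_steps = 0
            self.step_skipped = True
        else:
            for i in range(len(grads)):
                grads[i] /= self.scale  # unscale so optimizer.step() is valid
            self._clean_steps += 1
            self.step_skipped = False
            if self._clean_steps >= self.growth_interval:
                self.scale *= 2.0       # grow after a streak of clean steps
                self._clean_steps = 0
```

With a scale of 8.0, a loss of 1.5 is yielded as 12.0, and scaled gradients of [8.0, 16.0] are unscaled back to [1.0, 2.0] on exit; an inf gradient instead halves the scale and flags the step as skipped.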

Note

If Amp is using explicit FP32 master params (which is the default for opt_level=O2, and can also be manually enabled by supplying master_weights=True to amp.initialize), any FP16 gradients are copied to FP32 master gradients before being unscaled. optimizer.step() will then apply the unscaled master gradients to the master params.

Warning

If Amp is using explicit FP32 master params, only the FP32 master gradients will be unscaled. The direct .grad attributes of any FP16 model params will remain scaled after context manager exit. This subtlety affects gradient clipping. See “Gradient clipping” under Advanced Amp Usage for best practices.

Parameters

Warning

If delay_unscale is True for a given backward pass, optimizer.step() cannot be called yet after context manager exit, and must wait for another, later backward context manager invocation with delay_unscale left as False.

apex.amp.master_params(optimizer)[source]

Generator expression that iterates over the params owned by optimizer.

Parameters

optimizer – An optimizer previously returned from amp.initialize.
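Conceptually, master_params walks the optimizer’s param_groups, which under O2 have been rewired to hold the FP32 master weights. A toy sketch of that idea, using an invented dict-based stand-in for an optimizer rather than apex’s actual code:

```python
def toy_master_params(optimizer):
    # Yield each param in each param_group; under O2 the groups hold
    # the FP32 master weights, so this iterates over the masters.
    # Toy illustration only, not apex's implementation.
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p

class FakeOptimizer:
    """Minimal stand-in exposing param_groups like a torch optimizer."""
    def __init__(self, param_groups):
        self.param_groups = param_groups
```

This is the shape of API that makes opt_level-agnostic gradient clipping possible: clip over master_params(optimizer) and you always touch the params the optimizer will actually step on.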

Checkpointing

To properly save and load your Amp training, we introduce amp.state_dict(), which contains all loss_scalers and their corresponding unskipped steps, as well as amp.load_state_dict() to restore these attributes.

In order to get bitwise accuracy, we recommend the following workflow:

```python
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...
```

Note that we recommend restoring the model using the same opt_level. Also note that we recommend calling the load_state_dict methods after amp.initialize.

Advanced use cases

The unified Amp API supports gradient accumulation across iterations, multiple backward passes per iteration, multiple models/optimizers, custom/user-defined autograd functions, and custom data batch classes. Gradient clipping and GANs also require special treatment, but this treatment does not need to change for different opt_levels. Further details can be found in Advanced Amp Usage.

Transition guide for old API users

We strongly encourage moving to the new Amp API, because it’s more versatile, easier to use, and future proof. The original FP16_Optimizer and the old “Amp” API are deprecated, and subject to removal at any time.

For users of the old “Amp” API

In the new API, opt-level O1 performs the same patching of the Torch namespace as the old thing called “Amp.” However, the new API allows static or dynamic loss scaling, while the old API only allowed dynamic loss scaling.

In the new API, the old call to amp_handle = amp.init(), and the returned amp_handle, are no longer exposed or necessary. The new amp.initialize() does the duty of amp.init() (and more). Therefore, any existing calls to amp_handle = amp.init() should be deleted.

The functions formerly exposed through amp_handle are now free functions accessible through the amp module.

The backward context manager must be changed accordingly. Old API:

```python
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```

New API:

```python
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```

For now, the deprecated “Amp” API documentation can still be found on the Github README: https://github.com/NVIDIA/apex/tree/master/apex/amp. The old API calls that annotate user functions to run with a particular precision are still honored by the new API.

For users of the old FP16_Optimizer

opt-level O2 is equivalent to FP16_Optimizer with dynamic_loss_scale=True. Once again, the backward pass must be changed to the unified version:

```python
# old FP16_Optimizer API
optimizer.backward(loss)

# new unified API
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```

One annoying aspect of FP16_Optimizer was that the user had to manually convert their model to half (either by calling .half() on it, or using a function or module wrapper from apex.fp16_utils), and also manually call .half() on input data. **Neither of these is necessary in the new API. No matter what opt_level you choose, you can and should simply build your model and pass input data in the default FP32 format.** The new Amp API will perform the right conversions during model, optimizer = amp.initialize(model, optimizer, opt_level=....) based on the opt_level and any overridden flags. Floating point input data may be FP32 or FP16, but you may as well just let it be FP16, because the model returned by amp.initialize will have its forward method patched to cast the input data appropriately.