AmpOptimWrapper — mmengine 0.10.7 documentation
class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', dtype=None, use_fsdp=False, **kwargs)[source]¶
A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp.
AmpOptimWrapper provides a unified interface with OptimWrapper, so AmpOptimWrapper can be used in the same way as OptimWrapper.
Warning
AmpOptimWrapper requires PyTorch >= 1.6.
Parameters:
- loss_scale (float or str or dict) – The initial configuration of torch.cuda.amp.GradScaler. See the detailed argument introduction in the PyTorch AMP documentation. Defaults to dynamic.
  - "dynamic": Initialize GradScaler without any arguments.
  - float: Initialize GradScaler with init_scale.
  - dict: Initialize GradScaler with a more detailed configuration.
- dtype (str or torch.dtype, optional) – The data type to autocast in amp. If a str is given, it will be converted to torch.dtype. Valid str values are 'float16', 'bfloat16', 'float32' and 'float64'. If set to None, the default data type will be used. Defaults to None. New in version 0.6.1.
- use_fsdp (bool) – Use ShardedGradScaler when it is True. It should be enabled when using FullyShardedDataParallel. Defaults to False. New in version 0.8.0.
- **kwargs – Keyword arguments passed to OptimWrapper.
Warning
The dtype argument is only available with PyTorch >= 1.10.0. If you use an older version of PyTorch, it will be ignored.
Note
If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.
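As a quick orientation, here is a minimal construction sketch (not taken from the original page); the toy model, optimizer and keyword values are illustrative assumptions, and a CUDA device is assumed for float16 autocast:

```python
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

# Illustrative toy model and optimizer; any nn.Module / Optimizer works.
model = nn.Linear(4, 2).cuda()
optimizer = SGD(model.parameters(), lr=0.01)

# Dynamic loss scaling with float16 autocast; extra keyword arguments
# such as accumulative_counts are forwarded to OptimWrapper.
optim_wrapper = AmpOptimWrapper(
    optimizer=optimizer,
    loss_scale='dynamic',
    dtype='float16',
    accumulative_counts=1)
```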
backward(loss, **kwargs)[source]¶
Perform gradient back propagation with loss_scaler.
Parameters:
- loss (torch.Tensor) – The loss of current iteration.
- kwargs – Keyword arguments passed to torch.Tensor.backward()
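A hedged sketch of backward in isolation; the toy model, random inputs and summed loss are illustrative assumptions, and a CUDA device is assumed:

```python
import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(4, 2).cuda()
optim_wrapper = AmpOptimWrapper(optimizer=SGD(model.parameters(), lr=0.01))

# Forward pass under autocast, then backward through the wrapper, which
# scales the loss with the internal GradScaler before calling
# torch.Tensor.backward().
with optim_wrapper.optim_context(model):
    loss = model(torch.randn(8, 4).cuda()).sum()
optim_wrapper.backward(loss)
```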
load_state_dict(state_dict)[source]¶
Load and parse the state dictionary of optimizer and loss_scaler.
If state_dict contains “loss_scaler.”, the loss_scaler will load the corresponding keys. Otherwise, only the optimizer will load the state dictionary.
Parameters:
state_dict (dict) – The state dict of optimizer and loss_scaler.
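A sketch of restoring a previously saved state; the checkpoint path is hypothetical and stands for a dict produced earlier by state_dict():

```python
import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(4, 2).cuda()
optim_wrapper = AmpOptimWrapper(optimizer=SGD(model.parameters(), lr=0.01))

# "amp_checkpoint.pth" is a hypothetical file saved from a previous
# optim_wrapper.state_dict() call.  Keys prefixed with "loss_scaler."
# restore the GradScaler; the remaining keys restore the optimizer.
state = torch.load('amp_checkpoint.pth')
optim_wrapper.load_state_dict(state)
```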
optim_context(model)[source]¶
Enables the context for mixed precision training, and enables the context for disabling gradient synchronization during gradient accumulation.
Parameters:
model (nn.Module) – The training model.
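A minimal sketch of the context manager in a training loop; the random batches and accumulative_counts=2 are illustrative assumptions, and a CUDA device is assumed:

```python
import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(4, 2).cuda()
optim_wrapper = AmpOptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.01), accumulative_counts=2)

for batch in [torch.randn(8, 4).cuda() for _ in range(4)]:
    # Inside the context the forward pass runs under autocast; with a
    # distributed model, gradient synchronization is also skipped on the
    # non-update steps of gradient accumulation.
    with optim_wrapper.optim_context(model):
        loss = model(batch).sum()
    optim_wrapper.update_params(loss)
```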
state_dict()[source]¶
Get the state dictionary of optimizer and loss_scaler.
Based on the state dictionary of the optimizer, the returned state dictionary will add a key named “loss_scaler”.
Returns:
The merged state dict of loss_scaler and optimizer.
Return type:
dict
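A sketch of saving the combined state; the checkpoint path is a hypothetical name and a CUDA device is assumed:

```python
import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(4, 2).cuda()
optim_wrapper = AmpOptimWrapper(optimizer=SGD(model.parameters(), lr=0.01))

# The returned dict holds the optimizer state plus an extra
# "loss_scaler" entry with the GradScaler state.
state = optim_wrapper.state_dict()
torch.save(state, 'amp_checkpoint.pth')  # hypothetical checkpoint path
```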
step(**kwargs)[source]¶
Update parameters with loss_scaler.
Parameters:
kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
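A sketch of driving the update manually with backward() and step(); in most training loops update_params(loss) wraps backward, step and zero_grad in one call. The toy model, loss and CUDA device are assumptions:

```python
import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(4, 2).cuda()
optim_wrapper = AmpOptimWrapper(optimizer=SGD(model.parameters(), lr=0.01))

with optim_wrapper.optim_context(model):
    loss = model(torch.randn(8, 4).cuda()).sum()
optim_wrapper.backward(loss)

# step() calls optimizer.step() through the internal GradScaler, which
# unscales the gradients and skips the update if an overflow is found.
optim_wrapper.step()
optim_wrapper.zero_grad()
```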