AmpOptimWrapper — mmengine 0.10.7 documentation

class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', dtype=None, use_fsdp=False, **kwargs)[source]

A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp.

AmpOptimWrapper provides a unified interface with OptimWrapper, so it can be used in the same way as OptimWrapper.
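
A minimal sketch of using AmpOptimWrapper in a hand-written training step (the toy model, data and SGD settings below are illustrative assumptions, not part of the API; a CUDA device is assumed):

```python
import torch
import torch.nn as nn

from mmengine.optim import AmpOptimWrapper

# Toy model and optimizer; any nn.Module and torch.optim optimizer work.
model = nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wrap the optimizer. loss_scale='dynamic' builds torch.cuda.amp.GradScaler
# with its default arguments.
optim_wrapper = AmpOptimWrapper(optimizer=optimizer, loss_scale='dynamic')

inputs = torch.randn(8, 4).cuda()
targets = torch.randn(8, 2).cuda()

# optim_context enables autocast for the forward pass; update_params then
# performs backward, step and zero_grad with loss scaling handled internally.
with optim_wrapper.optim_context(model):
    loss = nn.functional.mse_loss(model(inputs), targets)
optim_wrapper.update_params(loss)
```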

Warning

AmpOptimWrapper requires PyTorch >= 1.6.

Parameters:

loss_scale (float or str or dict) – The initial configuration of torch.cuda.amp.GradScaler. 'dynamic' initializes GradScaler with its default arguments, a float is used as the initial scale, and a dict is passed to GradScaler as keyword arguments. Defaults to 'dynamic'.

dtype (str or torch.dtype, optional) – The data type to autocast in amp. If a str is given, it will be converted to torch.dtype. Valid str formats are 'float16', 'bfloat16', 'float32' and 'float64'. Defaults to None.

use_fsdp (bool) – Use ShardedGradScaler when it is True; it should be enabled when training with FullyShardedDataParallel. Defaults to False.

**kwargs – Keyword arguments passed to OptimWrapper.

Warning

The dtype argument is only available with PyTorch >= 1.10.0. If you use an older version of PyTorch, it will be ignored.

Note

If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.
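
When gradient accumulation is configured through a runner config, a hedged sketch of the optim_wrapper field looks like this (the optimizer settings and accumulation count are illustrative):

```python
# Parameters are updated every `accumulative_counts` iterations; with
# IterBasedRunner, scale max_iters accordingly as noted above.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    loss_scale='dynamic',
    accumulative_counts=4,
)
```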

backward(loss, **kwargs)[source]

Perform gradient backpropagation with loss_scaler.

Parameters:

loss (torch.Tensor) – The loss of the current iteration.

kwargs – Keyword arguments passed to torch.Tensor.backward().

load_state_dict(state_dict)[source]

Load and parse the state dictionary of optimizer and loss_scaler.

If state_dict contains “loss_scaler.”, the loss_scaler will load the corresponding keys. Otherwise, only the optimizer will load the state dictionary.

Parameters:

state_dict (dict) – The state dict of optimizer and loss_scaler.

optim_context(model)[source]

Enables the autocast context for mixed precision training, and disables gradient synchronization during gradient accumulation.

Parameters:

model (nn.Module) – The training model.
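
A small sketch of calling optim_context directly, reusing the illustrative model, inputs and optim_wrapper from the sketch above:

```python
# The forward pass runs under torch.autocast, so supported ops (e.g. matmul)
# use half precision on CUDA.
with optim_wrapper.optim_context(model):
    preds = model(inputs)
```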

state_dict()[source]

Get the state dictionary of optimizer and loss_scaler.

The returned state dictionary is based on the optimizer's state dictionary, with an additional key named “loss_scaler”.

Returns:

The merged state dict of loss_scaler and optimizer.

Return type:

dict
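
A hedged sketch of a checkpoint round-trip with state_dict() and load_state_dict(), reusing the illustrative optim_wrapper from above (the file name is an assumption):

```python
import torch

# The returned dict is the optimizer state plus a 'loss_scaler' entry.
ckpt = optim_wrapper.state_dict()
assert 'loss_scaler' in ckpt
torch.save(ckpt, 'optim_wrapper.pth')  # illustrative file name

# Later, restore both the optimizer and the loss scaler.
optim_wrapper.load_state_dict(torch.load('optim_wrapper.pth'))
```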

step(**kwargs)[source]

Update parameters with loss_scaler.

Parameters:

kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
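
For completeness, backward(), step() and zero_grad() can also be driven manually instead of through update_params(); a minimal sketch reusing the illustrative objects from the first example:

```python
# One manual training step; loss scaling is handled inside the wrapper.
with optim_wrapper.optim_context(model):
    loss = nn.functional.mse_loss(model(inputs), targets)
optim_wrapper.backward(loss)  # scales the loss, then calls loss.backward()
optim_wrapper.step()          # unscales grads, optimizer step, scaler update
optim_wrapper.zero_grad()     # clear gradients for the next iteration
```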