OptimWrapper — mmengine 0.10.7 documentation
class mmengine.optim.OptimWrapper(optimizer, accumulative_counts=1, clip_grad=None)[source]¶
Optimizer wrapper provides a common interface for updating parameters.
Optimizer wrapper provides a unified interface for single precision training and automatic mixed precision training on different hardware. OptimWrapper encapsulates an optimizer to provide simplified interfaces for commonly used training techniques such as gradient accumulation and gradient clipping. OptimWrapper implements the basic logic of gradient accumulation and gradient clipping based on torch.optim.Optimizer. Subclasses only need to override a few methods to implement mixed precision training. See more information in AmpOptimWrapper.
Parameters:
- optimizer (Optimizer) – Optimizer used to update model parameters.
- accumulative_counts (int) – The number of iterations to accumulate gradients. The parameters will be updated once every accumulative_counts iterations.
- clip_grad (dict, optional) – If clip_grad is not None, it will be used as the arguments of torch.nn.utils.clip_grad_norm_() or torch.nn.utils.clip_grad_value_(). clip_grad should be a dict whose keys are interpreted as follows (a constructor sketch for both modes follows this list):
  If the key type is not set, or type is "norm", the accepted keys are:
  - max_norm (float or int): Max norm of the gradients.
  - norm_type (float or int): Type of the used p-norm. Can be 'inf' for infinity norm.
  - error_if_nonfinite (bool): If True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf. Defaults to False (will switch to True in the future).
  If the key type is set to "value", the accepted keys are:
  - clip_value (float or int): Maximum allowed value of the gradients. The gradients are clipped in the range (-clip_value, +clip_value).
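For illustration, a minimal sketch of passing clip_grad directly to the constructor (the learning rate and clipping thresholds are arbitrary):

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
# Clip by norm (the default when the key ``type`` is not set).
wrapper_norm = OptimWrapper(
    SGD(model.parameters(), lr=0.1),
    clip_grad=dict(max_norm=0.2, norm_type=2))
# Clip by value.
wrapper_value = OptimWrapper(
    SGD(model.parameters(), lr=0.1),
    clip_grad=dict(type='value', clip_value=0.2))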
Note
If accumulative_counts is larger than 1, calling update_params() under the context of optim_context can avoid unnecessary gradient synchronization.
Note
If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts (a config sketch follows this note).
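A hedged config sketch of that adjustment (the numbers and optimizer settings are only illustrative):

# Original schedule: 10000 parameter updates with plain training. With
# accumulative_counts=2 each update consumes two iterations, so max_iters
# is doubled to keep the same number of parameter updates.
train_cfg = dict(by_epoch=False, max_iters=10000 * 2)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.1),
    accumulative_counts=2)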
Note
The subclass should ensure that once update_params() is called, _inner_count += 1 is performed automatically.
Examples
Config sample of OptimWrapper that enables gradient clipping by norm:

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    accumulative_counts=1,
    clip_grad=dict(max_norm=0.2))
Config sample of OptimWrapper that enables gradient clipping by value:

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    accumulative_counts=1,
    clip_grad=dict(type='value', clip_value=0.2))
Use OptimWrapper to update the model:

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
dataset = torch.randn(10, 1, 1)
dataloader = DataLoader(dataset)
optimizer = SGD(model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optimizer)

for data in dataloader:
    loss = model(data)
    optim_wrapper.update_params(loss)
Enable gradient accumulation:

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    accumulative_counts=3,
    clip_grad=dict(max_norm=0.2))

from torch.nn.parallel import DistributedDataParallel

ddp_model = DistributedDataParallel(model)
optimizer = SGD(ddp_model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
optim_wrapper.initialize_count_status(ddp_model, 0, len(dataloader))
If the model is an instance of DistributedDataParallel, the optim_context context manager can avoid unnecessary gradient synchronization:

for idx, data in enumerate(dataloader):
    with optim_wrapper.optim_context(ddp_model):
        loss = ddp_model(data)
    optim_wrapper.update_params(loss)
backward(loss, **kwargs)[source]¶
Perform gradient backpropagation.
Provide a unified backward interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic. For example, torch.cuda.amp requires some extra operations on GradScaler during the backward process.
Note
If a subclass of OptimWrapper overrides backward, it must still perform _inner_count += 1.
Parameters:
- loss (torch.Tensor) – The loss of current iteration.
- kwargs – Keyword arguments passed to torch.Tensor.backward().
Return type:
None
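As a rough illustration of the note above, here is a minimal sketch of a hypothetical subclass (LoggingOptimWrapper is not part of mmengine) that overrides backward while keeping the _inner_count bookkeeping:

import torch
from mmengine.optim import OptimWrapper

class LoggingOptimWrapper(OptimWrapper):
    """Hypothetical subclass, used only for illustration."""

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        # Any custom pre-/post-backward logic would go here.
        print(f'running backward at inner count {self._inner_count}')
        loss.backward(**kwargs)
        # Required bookkeeping when overriding ``backward`` (see the note
        # in the class docstring above).
        self._inner_count += 1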
initialize_count_status(model, init_counts, max_counts)[source]¶
Initialize gradient accumulation related attributes.
OptimWrapper can be used without calling initialize_count_status. However, consider the case where len(dataloader) == 10 and accumulative_counts == 3: since 10 is not divisible by 3, the last iteration would not trigger optimizer.step(), resulting in one fewer parameter update. Calling this method gives the wrapper the information it needs to still update parameters at the final iteration.
Parameters:
- model (nn.Module) – Training model
- init_counts (int) – The initial value of the inner count.
- max_counts (int) – The maximum value of the inner count.
Return type:
None
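A minimal usage sketch for the scenario described above (the model, learning rate and dataset are only illustrative):

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
dataloader = DataLoader(torch.randn(10, 1, 1))  # 10 iterations per epoch
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1),
                             accumulative_counts=3)
# Tell the wrapper where counting starts and ends so that the leftover
# iteration (10 is not divisible by 3) still triggers a parameter update.
optim_wrapper.initialize_count_status(model, init_counts=0,
                                      max_counts=len(dataloader))

for data in dataloader:
    loss = model(data)
    optim_wrapper.update_params(loss)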
property inner_count¶
Get the inner count of the optimizer wrapper, i.e. the number of times update_params() has been called.
optim_context(model)[source]¶
A context manager for gradient accumulation and automatic mixed precision training.
If subclasses need to enable a context for mixed precision training, e.g. AmpOptimWrapper, the corresponding context should be enabled in optim_context. Since OptimWrapper uses default fp32 training, optim_context only enables the context that blocks unnecessary gradient synchronization during gradient accumulation.
If the model has a no_sync method (which blocks gradient synchronization) and self._accumulative_counts != 1, the model will not synchronize gradients at iterations where the parameters are not going to be updated. Otherwise, this method enables an empty context.
Parameters:
model (nn.Module) – The training model.
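A rough sketch (an assumption about the behaviour, not the actual mmengine source) of what optim_context amounts to for plain fp32 training:

from contextlib import contextmanager, nullcontext
import torch.nn as nn

@contextmanager
def optim_context_sketch(optim_wrapper, model: nn.Module):
    # Block DDP gradient synchronization when this iteration will not
    # update parameters; otherwise run under an empty context.
    if hasattr(model, 'no_sync') and not optim_wrapper.should_sync():
        sync_ctx = model.no_sync()
    else:
        sync_ctx = nullcontext()
    with sync_ctx:
        yield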
scale_loss(loss)[source]¶
Get the scaled loss according to _accumulative_counts, _inner_count and max_counts.
Parameters:
loss (torch.Tensor) – Original loss calculated by model.
Returns:
Scaled loss.
Return type:
loss (torch.Tensor)
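A short illustrative sketch of the scaling under gradient accumulation (the model, learning rate and max_counts are arbitrary):

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1),
                             accumulative_counts=3)
optim_wrapper.initialize_count_status(model, init_counts=0, max_counts=10)

loss = model(torch.randn(1, 1)).sum()
# With accumulative_counts=3 the loss is expected to be scaled by 1/3 here,
# so the three accumulated gradients approximate an average over the window.
scaled_loss = optim_wrapper.scale_loss(loss)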
should_sync()[source]¶
Decide whether the automatic gradient synchronization should be allowed at the current iteration.
It takes effect when gradient accumulation is used, to skip synchronization at the iterations where the parameters are not updated.
Since should_sync is called by optim_context(), which runs before backward(), self._inner_count += 1 has not happened yet at that point. The decision is therefore based on self._inner_count + 1.
Returns:
Whether to allow the automatic gradient synchronization at the current iteration.
Return type:
bool
should_update()[source]¶
Decide whether the parameters should be updated at the current iteration.
Called by update_params() to check whether the optimizer wrapper should update parameters at the current iteration.
Returns:
Whether to update parameters.
Return type:
bool
step(**kwargs)[source]¶
A wrapper of Optimizer.step.
Provide a unified step interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic. For example, torch.cuda.amp requires some extra operations on GradScaler during the step process.
Clip gradients if clip_grad_kwargs is not None, and then update parameters.
Parameters:
kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
Return type:
None
update_params(loss, step_kwargs=None, zero_kwargs=None)[source]¶
Update parameters in the optimizer.
Parameters:
- loss (torch.Tensor) – A tensor for back propagation.
- step_kwargs (dict) – Arguments for optimizer.step. Defaults to None. New in version v0.4.0.
- zero_kwargs (dict) – Arguments for optimizer.zero_grad. Defaults to None. New in version v0.4.0.
Return type:
None
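A small self-contained sketch of forwarding the optional keyword arguments (the values are only illustrative):

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1))
loss = model(torch.randn(1, 1)).sum()

# ``step_kwargs`` is forwarded to optimizer.step() and ``zero_kwargs`` to
# optimizer.zero_grad() (set_to_none is a standard zero_grad option).
optim_wrapper.update_params(
    loss,
    step_kwargs=dict(),
    zero_kwargs=dict(set_to_none=True))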
zero_grad(**kwargs)[source]¶
A wrapper of Optimizer.zero_grad.
Provide a unified zero_grad interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic.
Parameters:
kwargs – Keyword arguments passed to torch.optim.Optimizer.zero_grad().
Return type:
None