OptimWrapperDict — mmengine 0.10.7 documentation
class mmengine.optim.OptimWrapperDict(**optim_wrapper_dict)[source]¶
A dictionary container of OptimWrapper.
If a runner is training with multiple optimizers, all optimizer wrappers should be managed by OptimWrapperDict, which is built by CustomOptimWrapperConstructor. OptimWrapperDict will load and save the state dictionaries of all optimizer wrappers.
Considering the semantic ambiguity of calling update_params or backward() on all optimizer wrappers at once, OptimWrapperDict does not implement these methods.
Examples
```python
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapperDict, OptimWrapper

model1 = nn.Linear(1, 1)
model2 = nn.Linear(1, 1)
optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1))
optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1))
optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1,
                                      model2=optim_wrapper2)
```
Note
The optimizer wrapper contained in OptimWrapperDict
can be accessed in the same way as dict.
Parameters:
**optim_wrapper_dict (OptimWrapper) – A dictionary of OptimWrapper instances.
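As the note above says, the contained optimizer wrappers can be read like entries of a plain dict. A minimal sketch, reusing the optim_wrapper_dict built in the example above (the membership and length checks assume the usual dict-like protocol of mmengine 0.10.x):

```python
# Look up a wrapper by the keyword name it was registered under.
wrapper = optim_wrapper_dict['model1']
assert wrapper is optim_wrapper1

# Membership and length also behave like a plain dict.
assert 'model2' in optim_wrapper_dict
assert len(optim_wrapper_dict) == 2
```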
backward(loss, **kwargs)[source]¶
Since OptimWrapperDict does not know which optimizer wrapper's backward method should be called (the loss_scaler may differ between AmpOptimWrapper instances), this method is not implemented.
The target optimizer wrapper should be accessed from OptimWrapperDict and its backward called directly.
Parameters:
loss (Tensor) –
Return type:
None
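A minimal sketch of the per-wrapper pattern, reusing the names from the example above; loss1 here is a hypothetical loss for model1:

```python
import torch

# Hypothetical loss; in practice this comes from a real forward pass.
loss1 = model1(torch.ones(1, 1)).sum()

# Call backward on the specific wrapper, not on the dict itself.
optim_wrapper_dict['model1'].backward(loss1)
```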
get_lr()[source]¶
Get the learning rate of all optimizers.
Returns:
Learning rate of all optimizers.
Return type:
Dict[str, List[float]]
get_momentum()[source]¶
Get the momentum of all optimizers.
Returns:
Momentum of all optimizers.
Return type:
Dict[str, List[float]]
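A minimal sketch of what these calls return for the dict built above; the exact key format shown in the comments is an assumption based on mmengine 0.10.x, which prefixes each entry with the wrapper name:

```python
lrs = optim_wrapper_dict.get_lr()
momentums = optim_wrapper_dict.get_momentum()

# Expected shape of the results (key format assumed):
# lrs       -> {'model1.lr': [0.1], 'model2.lr': [0.1]}
# momentums -> {'model1.momentum': [0.0], 'model2.momentum': [0.0]}
print(lrs)
print(momentums)
```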
initialize_count_status(model, cur_iter, max_iters)[source]¶
Do nothing but provide a unified interface with OptimWrapper.
Since OptimWrapperDict does not know the correspondence between models and optimizer wrappers, initialize_count_status does nothing here; each optimizer wrapper should call initialize_count_status separately.
Parameters:
- model (Module) –
- cur_iter (int) –
- max_iters (int) –
Return type:
None
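A minimal sketch of initializing each wrapper separately, matching OptimWrapper.initialize_count_status(model, init_counts, max_counts); max_iters below is an assumed training length:

```python
max_iters = 1000  # assumed total number of training iterations

# Each wrapper is initialized with the model it actually optimizes.
optim_wrapper_dict['model1'].initialize_count_status(model1, 0, max_iters)
optim_wrapper_dict['model2'].initialize_count_status(model2, 0, max_iters)
```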
items()[source]¶
A generator to get the name and the corresponding OptimWrapper.
Return type:
Iterator[Tuple[str, OptimWrapper]]
keys()[source]¶
A generator to get the name of OptimWrapper.
Return type:
Iterator[str]
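A minimal sketch of iterating over the dict built above:

```python
# (name, wrapper) pairs, like dict.items().
for name, wrapper in optim_wrapper_dict.items():
    print(name, type(wrapper).__name__)

# Names only, like dict.keys().
assert list(optim_wrapper_dict.keys()) == ['model1', 'model2']
```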
load_state_dict(state_dict)[source]¶
Load the state dictionary from state_dict.
Parameters:
state_dict (dict) – Each key-value pair in state_dict represents the name and the state dictionary of the corresponding OptimWrapper.
Return type:
None
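A minimal round-trip sketch together with state_dict() (documented below); the assertion assumes the saved dict is keyed by wrapper name, as described in the parameter above:

```python
# One entry per optimizer wrapper, keyed by its name.
state = optim_wrapper_dict.state_dict()
assert set(state) == {'model1', 'model2'}

# Restore into an OptimWrapperDict built with the same keys.
optim_wrapper_dict.load_state_dict(state)
```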
optim_context(model)[source]¶
optim_context should be called by each optimizer wrapper separately.
Parameters:
model (Module) –
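A minimal sketch of the per-wrapper pattern, reusing model1 from above; optim_context is a context manager on OptimWrapper:

```python
import torch

# Enter the context of one wrapper at a time, e.g. so that autocast or
# gradient-accumulation bookkeeping applies to that model only.
with optim_wrapper_dict['model1'].optim_context(model1):
    loss1 = model1(torch.ones(1, 1)).sum()
```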
property param_groups¶
Returns the parameter groups of each OptimWrapper.
state_dict()[source]¶
Get the state dictionary of all optimizer wrappers.
Returns:
Each key-value pair in the dictionary represents the name and state dictionary of corresponding OptimWrapper.
Return type:
dict
step(**kwargs)[source]¶
Since the backward method is not implemented, the step method should not be implemented either.
Return type:
None
update_params(loss, step_kwargs=None, zero_kwargs=None)[source]¶
Updating all optimizer wrappers at once would lead to duplicate backward errors, and OptimWrapperDict does not know which optimizer wrapper should be updated; therefore, this method is not implemented.
The target optimizer wrapper should be accessed from OptimWrapperDict and its update_params called directly.
Parameters:
- loss (Tensor) –
- step_kwargs (dict, optional) – Defaults to None.
- zero_kwargs (dict, optional) – Defaults to None.
Return type:
None
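A minimal training-step sketch using the names from the example above; the input batch is hypothetical:

```python
import torch

x = torch.ones(4, 1)  # hypothetical input batch

# Update each model through its own wrapper; update_params runs
# backward, step and zero_grad for that wrapper only.
optim_wrapper_dict['model1'].update_params(model1(x).sum())
optim_wrapper_dict['model2'].update_params(model2(x).sum())
```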
values()[source]¶
A generator to get OptimWrapper.
Return type:
Iterator[OptimWrapper]
zero_grad(**kwargs)[source]¶
Set the gradients of all optimizer wrappers to zero.
Return type:
None
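Unlike backward and step, zero_grad does act on all contained wrappers at once; a minimal sketch:

```python
# Zeroes the gradients held by both optim_wrapper1 and optim_wrapper2.
optim_wrapper_dict.zero_grad()
```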