EMAHook – mmengine 0.10.7 documentation
class mmengine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, **kwargs)
A hook that applies an Exponential Moving Average (EMA) to the model's parameters during training.
Note
- EMAHook takes priority over CheckpointHook.
- In saved checkpoints, the original (non-averaged) model parameters are actually stored in the ema field, while the model's state_dict holds the averaged parameters.
- begin_iter and begin_epoch cannot be set at the same time.
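The last note implies that the hook schedules EMA either by epoch or by iteration, never both. A minimal sketch of that validation is shown below; the function name resolve_ema_start is hypothetical, and the rules that negative values are rejected and that a positive begin_epoch selects epoch-based scheduling are assumptions, not the library's code.

```python
from typing import Tuple


def resolve_ema_start(begin_iter: int = 0, begin_epoch: int = 0) -> Tuple[bool, int]:
    """Hypothetical helper: validate the start arguments and pick a mode.

    Returns (enabled_by_epoch, begin_value).
    """
    # Assumption: negative start points are rejected.
    if begin_iter < 0 or begin_epoch < 0:
        raise ValueError("begin_iter and begin_epoch must be non-negative")
    # The Note forbids setting both at the same time.
    if begin_iter != 0 and begin_epoch != 0:
        raise ValueError("begin_iter and begin_epoch cannot be set at the same time")
    # Assumption: a positive begin_epoch switches to epoch-based scheduling;
    # otherwise the start point is counted in iterations.
    if begin_epoch > 0:
        return True, begin_epoch
    return False, begin_iter


print(resolve_ema_start(begin_epoch=5))   # epoch-based start
print(resolve_ema_start(begin_iter=100))  # iteration-based start
```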
Parameters:
- ema_type (str) – The type of EMA strategy to use. The supported strategies are listed in mmengine.model.averaged_model. Defaults to ‘ExponentialMovingAverage’.
- strict_load (bool) – Whether to strictly enforce that the keys of state_dict in the checkpoint match the keys returned by self.module.state_dict. Defaults to False. Changed in v0.3.0.
- begin_iter (int) – The iteration from which EMAHook is enabled. Defaults to 0.
- begin_epoch (int) – The epoch from which EMAHook is enabled. Defaults to 0.
- **kwargs – Keyword arguments passed to subclasses of BaseAveragedModel.
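A common way to enable the hook is through the runner config's custom_hooks list. In the hedged fragment below, the values are illustrative, and the momentum key is an assumption: it is not an EMAHook argument but is forwarded through **kwargs to the averaged-model class selected by ema_type.

```python
# Illustrative config fragment: register EMAHook via custom_hooks.
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExponentialMovingAverage',
        begin_epoch=5,      # start averaging from epoch 5; begin_iter stays at 0
        strict_load=False,
        # Forwarded through **kwargs to the averaged-model class
        # (assumed to be an ExponentialMovingAverage argument).
        momentum=0.0002,
    )
]
```

Only one of begin_epoch and begin_iter is set here, in line with the Note above.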
after_load_checkpoint(runner, checkpoint)
Resume the EMA parameters from the checkpoint.
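Parameters:
- runner (Runner) – The runner of the training process.
- checkpoint (dict) – The checkpoint that was loaded.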
Return type:
None
after_test_epoch(runner, metrics=None)
Restore the source model's original parameters from the EMA model after testing.
Parameters:
- runner (Runner) – The runner of the testing process.
- metrics (Dict[str, float], optional) – Evaluation results of all metrics on the test dataset. The keys are the names of the metrics, and the values are the corresponding results.
Return type:
None
after_train_iter(runner, batch_idx, data_batch=None, outputs=None)
Update the EMA parameters.
Parameters:
- runner (Runner) – The runner of the training process.
- batch_idx (int) – The index of the current batch in the train loop.
- data_batch (Sequence[dict], optional) – Data from the dataloader. Defaults to None.
- outputs (dict, optional) – Outputs from the model. Defaults to None.
Return type:
None
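after_train_iter blends the current source parameters into the EMA copy once per iteration. The sketch below illustrates an exponential moving average update on plain Python floats instead of tensors; the function name ema_update is hypothetical, and the convention that momentum weights the new source value, ema = (1 - momentum) * ema + momentum * src, is an assumption about ExponentialMovingAverage rather than something stated on this page.

```python
import copy
from typing import Dict


def ema_update(ema_params: Dict[str, float], src_params: Dict[str, float],
               momentum: float = 0.0002) -> None:
    """Hypothetical EMA step, applied in place.

    Assumed rule: ema = (1 - momentum) * ema + momentum * src.
    """
    for name, src_value in src_params.items():
        ema_params[name] = (1.0 - momentum) * ema_params[name] + momentum * src_value


# The EMA copy starts as a deep copy of the source parameters
# (compare before_run, "Create an ema copy of the model").
src = {"weight": 1.0, "bias": 0.0}
ema = copy.deepcopy(src)

# Simulate one optimizer step that changes the source parameters,
# then update the EMA copy with a large momentum so the effect is visible.
src["weight"] = 2.0
src["bias"] = 1.0
ema_update(ema, src, momentum=0.5)
print(ema)  # {'weight': 1.5, 'bias': 0.5}
```

With the small default momentum the EMA copy moves slowly, which is what makes it a smoothed version of the training weights.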
after_val_epoch(runner, metrics=None)
Restore the source model's original parameters from the EMA model after validation.
Parameters:
- runner (Runner) – The runner of the validation process.
- metrics (Dict[str, float], optional) – Evaluation results of all metrics on the validation dataset. The keys are the names of the metrics, and the values are the corresponding results.
Return type:
None
before_run(runner)
Create an EMA copy of the model.
Parameters:
runner (Runner) – The runner of the training process.
Return type:
None
before_save_checkpoint(runner, checkpoint)
Save the EMA parameters to the checkpoint.
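Parameters:
- runner (Runner) – The runner of the training process.
- checkpoint (dict) – The checkpoint to be saved.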
Return type:
None
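The Note that the original parameters end up in the ema field can be understood as a key-by-key swap performed when saving, undone again when resuming. The sketch below uses plain dicts with placeholder strings; the function name swap_ema_state_dict, the checkpoint keys state_dict and ema_state_dict, and the 'module.' key prefix are all assumptions for illustration.

```python
from typing import Any, Dict


def swap_ema_state_dict(checkpoint: Dict[str, Any], prefix: str = "module.") -> None:
    """Hypothetical swap between checkpoint['state_dict'] and checkpoint['ema_state_dict'].

    Assumed layout: EMA keys are the model keys plus a 'module.' prefix, and
    EMA-only bookkeeping keys without that prefix (e.g. 'steps') are skipped.
    """
    model_state = checkpoint["state_dict"]
    ema_state = checkpoint["ema_state_dict"]
    for key in ema_state:
        if key.startswith(prefix):
            model_key = key[len(prefix):]
            # Exchange the two values in place; no keys are added or removed.
            ema_state[key], model_state[model_key] = model_state[model_key], ema_state[key]


# Toy checkpoint with placeholder strings instead of tensors.
ckpt = {
    "state_dict": {"weight": "source"},
    "ema_state_dict": {"steps": 10, "module.weight": "averaged"},
}

# Before saving: state_dict now carries the averaged weights.
swap_ema_state_dict(ckpt)
saved_weight = ckpt["state_dict"]["weight"]

# On resume: swapping again restores the original layout.
swap_ema_state_dict(ckpt)
```

Because the swap is its own inverse, the same helper can serve both the save path and the resume path.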
before_test_epoch(runner)
Load parameter values from the EMA model into the source model before testing.
Parameters:
runner (Runner) – The runner of the testing process.
Return type:
None
before_train(runner)
Check that begin_epoch (or begin_iter) is smaller than max_epochs (or max_iters).
Parameters:
runner (Runner) – The runner of the training process.
Return type:
None
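The check in before_train guards against an EMA start point that training would never reach. A hedged sketch is below; the function name check_begin_within_max is hypothetical, and accepting equality (EMA starting at the very last epoch or iteration) is an assumption of this sketch, not confirmed by the description above.

```python
def check_begin_within_max(enabled_by_epoch: bool, begin: int, max_count: int) -> None:
    """Hypothetical pre-training check on the EMA start point.

    Raises ValueError if EMA would start after training ends.
    Assumption: begin == max_count is accepted.
    """
    unit = "epoch" if enabled_by_epoch else "iter"
    if begin > max_count:
        raise ValueError(
            f"begin_{unit} ({begin}) should not exceed max_{unit}s ({max_count})"
        )


check_begin_within_max(True, 5, 12)        # epoch-based: 5 <= 12, passes
check_begin_within_max(False, 1000, 1000)  # iter-based: equality accepted in this sketch
```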
before_val_epoch(runner)
Load parameter values from the EMA model into the source model before validation.
Parameters:
runner (Runner) – The runner of the validation process.
Return type:
None
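Taken together, before_val_epoch and after_val_epoch evaluate with the averaged weights and then hand the raw training weights back (before_test_epoch and after_test_epoch follow the same pattern). The class below is a hypothetical stand-in that swaps values between two dicts of floats; it is a sketch of the idea, not EMAHook's implementation.

```python
from typing import Dict


class EMASwapSketch:
    """Hypothetical stand-in for the val/test parameter swap."""

    def __init__(self, src: Dict[str, float], ema: Dict[str, float]) -> None:
        self.src = src  # parameters used for training
        self.ema = ema  # averaged parameters

    def _swap(self) -> None:
        # Exchange every parameter value between the two copies.
        for name in self.src:
            self.src[name], self.ema[name] = self.ema[name], self.src[name]

    def before_val_epoch(self) -> None:
        # Evaluate with the averaged weights.
        self._swap()

    def after_val_epoch(self) -> None:
        # Put the raw training weights back before training resumes.
        self._swap()


hook = EMASwapSketch(src={"w": 2.0}, ema={"w": 1.5})
hook.before_val_epoch()
during_val = hook.src["w"]  # averaged value is seen during validation
hook.after_val_epoch()
after_val = hook.src["w"]   # training value is restored
```

Swapping rather than copying keeps both sets of values intact, so no extra buffer is needed to restore training state.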