ExponentialMovingAverage - mmengine 0.10.7 documentation
class mmengine.model.ExponentialMovingAverage(model, momentum=0.0002, interval=1, device=None, update_buffers=False)
Implements the exponential moving average (EMA) of the model.
All parameters are updated with the following formula:
\[Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t\]
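To make the update rule concrete, here is a minimal scalar sketch in plain Python. The function name `ema_update` is hypothetical and not part of mmengine; it applies the formula above once to a single float, and the loop feeds it a constant observation.

```python
def ema_update(x_ema: float, x_t: float, momentum: float) -> float:
    """One EMA step: Xema_{t+1} = (1 - momentum) * Xema_t + momentum * X_t."""
    return (1 - momentum) * x_ema + momentum * x_t


# Start the average at 0.0 and repeatedly observe the value 1.0.
x_ema = 0.0
for step in range(3):
    x_ema = ema_update(x_ema, 1.0, momentum=0.1)
    print(step, round(x_ema, 4))
# prints:
# 0 0.1
# 1 0.19
# 2 0.271
```

Each step moves the average a fraction `momentum` of the way from its current value toward the new observation.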
Note
This momentum argument is different from the one used in optimizer classes and from the conventional notion of momentum. Mathematically, \(Xema_{t+1}\) is the moving average and \(X_t\) is the new observed value. The value of momentum is usually a small number, so that observed values update the EMA parameters slowly.
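One way to quantify "slowly": if the observed value stays fixed at \(X\), applying the formula \(t\) times gives \(Xema_t - X = (1 - momentum)^t (Xema_0 - X)\), so the gap shrinks geometrically. The plain-Python sketch below (helper names are hypothetical) uses this to estimate how many updates it takes to close half of the initial gap.

```python
import math


def remaining_gap(momentum: float, steps: float) -> float:
    """Fraction of the initial gap |Xema_0 - X| left after `steps` updates,
    assuming the observed value X stays constant."""
    return (1 - momentum) ** steps


def half_life_steps(momentum: float) -> float:
    """Number of updates needed to close half of the initial gap."""
    return math.log(0.5) / math.log(1 - momentum)


for m in (0.0002, 0.01, 0.1):
    print(f"momentum={m}: half-life ~ {half_life_steps(m):.1f} updates")
```

With the default momentum of 0.0002 the half-life is roughly 3,465 updates, compared with about 6.6 updates for a momentum of 0.1.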
Parameters:
- model (nn.Module) – The model to be averaged.
- momentum (float) – The momentum used for updating the EMA parameters. Defaults to 0.0002. The EMA parameters are updated with the formula \(averaged\_param = (1-momentum) * averaged\_param + momentum * source\_param\).
- interval (int) – Interval between two updates. Defaults to 1.
- device (torch.device, optional) – If provided, the averaged model will be stored on this device. Defaults to None.
- update_buffers (bool) – If True, running averages are computed for both the parameters and the buffers of the model. Defaults to False.
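The interaction of momentum, interval and update_buffers can be sketched without PyTorch. The class below, `SimpleEMA`, is a hypothetical and simplified stand-in, not mmengine's implementation: instead of an nn.Module, the model is represented by two dicts of floats standing in for `named_parameters()` and `named_buffers()`, and device placement is omitted. Three behaviors are assumptions made for illustration: the averaged copy starts from the model's current values, averaging happens on every interval-th call, and when update_buffers is False the buffers are copied from the source instead of averaged.

```python
from typing import Dict


class SimpleEMA:
    """Hypothetical, simplified EMA over named float values (not mmengine's class)."""

    def __init__(self, params: Dict[str, float], buffers: Dict[str, float],
                 momentum: float = 0.0002, interval: int = 1,
                 update_buffers: bool = False) -> None:
        # Assumption: the averaged copy starts from the model's current values.
        self.avg_params = dict(params)
        self.avg_buffers = dict(buffers)
        self.momentum = momentum
        self.interval = interval
        self.update_buffers = update_buffers
        self.steps = 0  # number of times update() has been called

    def _avg(self, avg: float, src: float) -> float:
        # Same formula as the momentum parameter description above.
        return (1 - self.momentum) * avg + self.momentum * src

    def update(self, params: Dict[str, float], buffers: Dict[str, float]) -> None:
        self.steps += 1
        # Assumption: average only on every `interval`-th call.
        if self.steps % self.interval == 0:
            for name, value in params.items():
                self.avg_params[name] = self._avg(self.avg_params[name], value)
            if self.update_buffers:
                for name, value in buffers.items():
                    self.avg_buffers[name] = self._avg(self.avg_buffers[name], value)
        if not self.update_buffers:
            # Assumption: buffers that are not averaged simply track the source.
            self.avg_buffers = dict(buffers)


ema = SimpleEMA({"w": 0.0}, {"running_mean": 0.0}, momentum=0.5, interval=2)
for value in (1.0, 1.0, 1.0, 1.0):
    ema.update({"w": value}, {"running_mean": value})
print(ema.avg_params, ema.avg_buffers)  # {'w': 0.75} {'running_mean': 1.0}
```

With interval=2, only the second and fourth calls average "w" (0.0 → 0.5 → 0.75), while the buffer follows the source because update_buffers is False.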
avg_func(averaged_param, source_param, steps)
Compute the exponential moving average of the parameters. The result is written into averaged_param in place, so the method returns nothing.
Parameters:
- averaged_param (Tensor) – The averaged parameters.
- source_param (Tensor) – The source parameters.
- steps (int) – The number of times the parameters have been updated.
Return type:
None
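Because avg_func returns None, the new average has to be written into averaged_param itself. The formula \((1 - momentum) \cdot a + momentum \cdot s\) can be rewritten as \(a + momentum \cdot (s - a)\), a linear interpolation; in PyTorch, one way to express that in place is `Tensor.lerp_(end, weight)`. The plain-Python sketch below imitates the in-place behavior on flat lists of floats. The name `avg_func_sketch` is hypothetical, and momentum is passed explicitly because there is no `self`; `steps` is accepted but unused, since the formula in this section does not depend on it.

```python
from typing import List


def avg_func_sketch(averaged_param: List[float], source_param: List[float],
                    steps: int, momentum: float = 0.0002) -> None:
    """Hypothetical in-place EMA step on flat lists of floats.

    Computes averaged = (1 - momentum) * averaged + momentum * source,
    written in interpolation form: averaged + momentum * (source - averaged).
    `steps` is unused because this EMA formula does not depend on it.
    """
    if len(averaged_param) != len(source_param):
        raise ValueError("averaged_param and source_param must have the same length")
    for i, (avg, src) in enumerate(zip(averaged_param, source_param)):
        averaged_param[i] = avg + momentum * (src - avg)


avg = [0.0, 2.0]
avg_func_sketch(avg, [1.0, 4.0], steps=1, momentum=0.25)
print(avg)  # [0.25, 2.5]
```

Mutating the caller's container rather than returning a new one mirrors the None return type documented above.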