Hook — mmengine 0.10.7 documentation

Hook programming is a programming pattern in which one or more mount points are set in a program. When execution reaches a mount point, all methods registered to it at runtime are automatically called. Hook programming increases the flexibility and extensibility of a program, since users can register custom methods to be called at a mount point without modifying the program's code.
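As a toy illustration of the pattern (not MMEngine code), a program can expose a mount point and invoke whatever callables were registered there; the names below are hypothetical:

```python
# A minimal sketch of the hook pattern, independent of MMEngine.
registered_hooks = []

def register_hook(fn):
    """Register a callable to be invoked at the mount point."""
    registered_hooks.append(fn)

def run_program(data):
    result = [x * 2 for x in data]  # core logic the user does not modify
    for fn in registered_hooks:     # mount point: call every registered hook
        fn(result)
    return result

messages = []
register_hook(lambda result: messages.append(f'processed {len(result)} items'))
run_program([1, 2, 3])
print(messages[0])  # processed 3 items
```

The program's own code never changes; behavior at the mount point is extended purely through registration.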

Built-in Hooks

MMEngine encapsulates many utilities as built-in hooks. These hooks fall into two categories, namely default hooks and custom hooks. The former are registered with the Runner by default, while the latter are registered by the user on demand.

Each hook has a corresponding priority. At each mount point, hooks with higher priority are called earlier by the Runner. Hooks sharing the same priority are called in their registration order. The priority list is as follows.

default hooks

| Name | Function | Priority |
| --- | --- | --- |
| RuntimeInfoHook | Update runtime information into the message hub | VERY_HIGH (10) |
| IterTimerHook | Update the time spent during iteration into the message hub | NORMAL (50) |
| DistSamplerSeedHook | Ensure the distributed Sampler shuffle is active | NORMAL (50) |
| LoggerHook | Collect logs from different components of the Runner and write them to terminal, JSON file, TensorBoard, wandb, etc. | BELOW_NORMAL (60) |
| ParamSchedulerHook | Update some hyper-parameters of the optimizer | LOW (70) |
| CheckpointHook | Save checkpoints periodically | VERY_LOW (90) |

custom hooks

| Name | Function | Priority |
| --- | --- | --- |
| EMAHook | Apply Exponential Moving Average (EMA) on the model during training | NORMAL (50) |
| EmptyCacheHook | Release all unoccupied cached GPU memory during training | NORMAL (50) |
| SyncBuffersHook | Synchronize model buffers at the end of each epoch | NORMAL (50) |
| ProfilerHook | Analyze the execution time and GPU memory usage of model operators | VERY_LOW (90) |

Note

It is not recommended to modify the priority of the default hooks, as hooks with lower priority may depend on hooks with higher priority. For example, CheckpointHook needs to have a lower priority than ParamSchedulerHook so that the saved optimizer state is correct. Also, the priority of custom hooks defaults to NORMAL (50).
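The ordering rule above can be sketched as a stable sort on the numeric priority value (a lower number means a higher priority, so it is called earlier); ties keep registration order. This is a conceptual sketch, not the Runner's actual implementation:

```python
# Sketch: hooks ordered by numeric priority; ties keep registration order
# because Python's sorted() is stable.
hooks = [
    ('CheckpointHook', 90),
    ('RuntimeInfoHook', 10),
    ('IterTimerHook', 50),
    ('DistSamplerSeedHook', 50),  # same priority as IterTimerHook, but
                                  # registered later, so called later
]
call_order = [name for name, priority in sorted(hooks, key=lambda h: h[1])]
print(call_order)
# ['RuntimeInfoHook', 'IterTimerHook', 'DistSamplerSeedHook', 'CheckpointHook']
```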

The two types of hooks are configured differently in the Runner: the configuration of default hooks is passed to the Runner's default_hooks parameter, and the configuration of custom hooks to its custom_hooks parameter, as follows.

```python
from mmengine.runner import Runner

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)
custom_hooks = [dict(type='EmptyCacheHook')]
runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()
```


CheckpointHook

CheckpointHook saves checkpoints at a given interval. In the case of distributed training, only the master process saves the checkpoints. The main features of CheckpointHook are as follows.

Some of its features are described below; for more, please read the CheckpointHook API documentation.

Save checkpoints by epoch. The default value of by_epoch is True, so the following config saves a checkpoint every 5 epochs:

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))
```

If you want to save checkpoints by iteration, you can set by_epoch to False and interval=5 to save them every 5 iterations:

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
```

LoggerHook

LoggerHook collects logs from different components of the Runner and writes them to the terminal, JSON file, TensorBoard, wandb, etc.

If we want to output (or save) the logs every 20 iterations, we can set the interval parameter and configure it as follows.

```python
default_hooks = dict(logger=dict(type='LoggerHook', interval=20))
```

If you are interested in how MMEngine manages logging, you can refer to logging.

ParamSchedulerHook

ParamSchedulerHook iterates through all optimizer parameter schedulers of the Runner and calls their step method to update the optimizer parameters in order. See Parameter Schedulers for more details about what parameter schedulers are.

ParamSchedulerHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.
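Conceptually, the hook just walks the Runner's scheduler list and steps each one in order. The sketch below uses a dummy scheduler and a plain dict as a stand-in optimizer; none of these names are MMEngine APIs:

```python
# Conceptual sketch of what ParamSchedulerHook does at its mount point:
# step every parameter scheduler of the Runner, in order.
class DummyScheduler:
    """Stand-in for a parameter scheduler that halves the lr each step."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        self.optimizer['lr'] *= 0.5

optimizer = {'lr': 0.1}
param_schedulers = [DummyScheduler(optimizer)]

def after_train_epoch(schedulers):
    for scheduler in schedulers:  # called by the Runner at the mount point
        scheduler.step()

after_train_epoch(param_schedulers)
print(optimizer['lr'])  # 0.05
```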

IterTimerHook

IterTimerHook is used to record the time taken to load data and iterate once.

IterTimerHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.

DistSamplerSeedHook

DistSamplerSeedHook calls the step method of the Sampler during distributed training to ensure that the shuffle operation takes effect.

DistSamplerSeedHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.
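The underlying mechanism: a distributed sampler typically derives its shuffle order from seed + epoch, so if the epoch counter is never advanced, every epoch reshuffles identically. A framework-free sketch of that idea, using a hypothetical shuffled_indices helper (the real torch DistributedSampler uses a torch.Generator, not random.Random):

```python
import random

def shuffled_indices(num_samples, seed, epoch):
    """Sketch of a distributed sampler's behavior: the permutation is
    derived from seed + epoch, so it only changes when the epoch does."""
    rng = random.Random(seed + epoch)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return indices

# If the epoch is never advanced, every epoch sees the same order.
same = shuffled_indices(8, seed=42, epoch=0) == shuffled_indices(8, seed=42, epoch=0)
# Advancing the epoch (what DistSamplerSeedHook ensures) changes the order.
different = shuffled_indices(8, seed=42, epoch=0) != shuffled_indices(8, seed=42, epoch=1)
print(same, different)
```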

RuntimeInfoHook

RuntimeInfoHook will update the current runtime information (e.g. epoch, iter, max_epochs, max_iters, lr, metrics, etc.) to the message hub at different mount points in the Runner so that other modules without access to the Runner can obtain this information.

RuntimeInfoHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.
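The idea can be sketched with a plain dictionary standing in for MMEngine's MessageHub (the real code goes through MessageHub instances rather than a module-level dict):

```python
# Sketch of the message-hub idea: the hook pushes runtime state into
# shared storage at its mount points, and any other module reads it
# without holding a reference to the Runner.
message_hub = {}  # stand-in for mmengine's MessageHub

def before_train_iter(runner_state):
    """What RuntimeInfoHook conceptually does at a mount point."""
    message_hub['iter'] = runner_state['iter']
    message_hub['lr'] = runner_state['lr']

# Somewhere in the Runner's training loop:
before_train_iter({'iter': 100, 'lr': 0.01})

# A module with no access to the Runner can still read runtime info:
print(message_hub['iter'], message_hub['lr'])  # 100 0.01
```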

EMAHook

EMAHook performs an exponential moving average operation on the model during training, with the aim of improving the robustness of the model. Note that the model generated by exponential moving average is only used for validation and testing, and does not affect training.
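Numerically, the averaged parameters trail the training parameters. The sketch below assumes the common EMA update form averaged = (1 - momentum) * averaged + momentum * source; the momentum value here is chosen for illustration and is much larger than a typical default:

```python
# Sketch of an exponential-moving-average update for one scalar weight.
# The update form and momentum are assumptions matching the common EMA
# formulation, not a verbatim copy of mmengine's implementation.
def ema_update(averaged, source, momentum=0.1):
    return (1 - momentum) * averaged + momentum * source

averaged = 0.0
for source in [1.0, 1.0, 1.0]:  # training weight stuck at 1.0
    averaged = ema_update(averaged, source)

print(round(averaged, 3))  # 0.271 -- the average lags behind the weight
```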

```python
custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```

EMAHook uses ExponentialMovingAverage by default, with StochasticWeightAverage and MomentumAnnealingEMA available as alternatives. Other averaging strategies can be selected by setting ema_type.

```python
custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]
```

See EMAHook API Reference for more usage.

EmptyCacheHook

EmptyCacheHook calls torch.cuda.empty_cache() to release all unoccupied cached GPU memory. The timing of releasing memory can be controlled by setting parameters like before_epoch, after_iter, and after_epoch, meaning before the start of each epoch, after each iteration, and after each epoch respectively.

The following config performs the release operation at the end of each epoch:

```python
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```

SyncBuffersHook

SyncBuffersHook synchronizes the buffers of the model at the end of each epoch during distributed training, e.g. the running_mean and running_var of BN layers.

```python
custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```

ProfilerHook

The ProfilerHook is used to analyze the execution time and GPU memory occupancy of model operators.

```python
custom_hooks = [dict(type='ProfilerHook', on_trace_ready=dict(type='tb_trace'))]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```

The profiling results will be saved in the tf_tracing_logs directory under work_dirs/{timestamp}, and can be visualized using TensorBoard with the command tensorboard --logdir work_dirs/{timestamp}/tf_tracing_logs.

For more information on the usage of the ProfilerHook, please refer to the ProfilerHook documentation.

Customize Your Hooks

If the built-in hooks provided by MMEngine do not cover your demands, you are encouraged to customize your own hooks by simply inheriting the base hook class and overriding the corresponding mount point methods.

For example, if you want to check whether the loss value is valid, i.e. not infinite, during training, you can simply override the after_train_iter method as below. The check will be performed after each training iteration.

```python
import torch

from mmengine.registry import HOOKS
from mmengine.hooks import Hook


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid during
    training.

    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None,
                         outputs=None):
        """All subclasses should override this method, if they need any
        operations after each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the
                train loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']), \
                runner.logger.info('loss become infinite or NaN!')
```

We simply pass the hook config to the custom_hooks parameter of the Runner, which will register the hooks when the Runner is initialized.

```python
from mmengine.runner import Runner

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()  # start training
```

Then the loss value is checked after each training iteration.

Note that the priority of a custom hook is NORMAL (50) by default. If you want to change the priority of the hook, you can set the priority key in the config.

```python
custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]
```

You can also set priority when defining classes.

```python
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    priority = 'ABOVE_NORMAL'
```