钩子Hook
参照MMDetection官网
mmdetection的钩子分为内置钩子与自定义钩子,内置钩子默认注册,自定义钩子需要自己注册。
优先级
每个钩子都有对应的优先级,在同一位点,钩子的优先级越高,越早被执行器调用,如果优先级一样,被调用的顺序和钩子注册的顺序一致。
优先级列表如下:
HIGHEST (0)
VERY_HIGH (10)
HIGH (30)
ABOVE_NORMAL (40)
NORMAL (50)
BELOW_NORMAL (60)
LOW (70)
VERY_LOW (90)
LOWEST (100)
内置钩子
1.CheckpointHook
自定义权重保存的间隔,如果是分布式多卡训练,则只有主(master)进程会保存权重,支持按 epoch 数或者 iteration 数保存权重
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))
# by_epoch 的默认值为 True,以 epoch作为保存间隔
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
# 以迭代次数 iteration作为保存间隔
其他参数:
max_keep_ckpts=10 # 只保存最新10个权重
save_best='auto' # 只保存最优权重,'auto'根据验证集的第一个评价指标判断权重最优
save_best='accuracy', rule='greater'
out_dir='/path/of/directory' # 指定保存权重的路径
2.LoggerHook
LoggerHook负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端,每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数如下:
default_hooks = dict(logger=dict(type='LoggerHook', interval=20))
3.ParamSchedulerHook
ParamSchedulerHook 遍历执行器的所有优化器参数调整策略(Parameter Scheduler)并逐个调用 step 方法更新优化器的参数。ParamSchedulerHook默认注册到执行器并且没有可配置的参数,无需对其做任何配置。
4.IterTimerHook
IterTimerHook 用于记录加载数据的时间以及迭代一次耗费的时间。IterTimerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。
5.DistSamplerSeedHook
DistSamplerSeedHook 在分布式训练时调用 Sampler 的 step 方法以确保 shuffle 参数生效。DistSamplerSeedHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。
6.RuntimeInfoHook
RuntimeInfoHook 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,以便其他无法访问执行器的模块能够获取到这些信息。RuntimeInfoHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。
7.EMAHook
EMAHook 在训练过程中对模型执行指数滑动平均操作,目的是提高模型的鲁棒性。注意:指数滑动平均生成的模型只用于验证和测试,不影响训练。
custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
EMAHook 默认使用 ExponentialMovingAverage,可选值还有 StochasticWeightAverage 和MomentumAnnealingEMA。可以通过设置 ema_type 使用其他的平均策略。
custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]
8.EmptyCacheHook
EmptyCacheHook 调用 torch.cuda.empty_cache() 释放未被使用的显存。可以通过设置 before_epoch, after_iter 以及 after_epoch 参数控制释显存的时机,
# 每一个 epoch 结束都会执行释放操作
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
before_epoch # (bool) Defaults to False
after_epoch # (bool) Defaults to True
after_iter # (bool) Defaults to False
9.SyncBuffersHook
SyncBuffersHook 在分布式训练每一轮(epoch)结束时同步模型的 buffer,例如 BN 层的 running_mean 以及 running_var。
custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
自定义钩子
通过继承基类并重写相应位点方法,所有的hook都储存在self._hook列表中,其顺序和优先级挂钩
例:在每次迭代后判断损失值是否无穷大,重写 after_train_iter 位点
import torch
from mmengine.registry import HOOKS
from mmengine.hooks import Hook
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
def __init__(self, interval=50):
self.interval = interval # 钩子执行的间隔,每iterations执行一次钩子
def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
"""
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
"""
if self.every_n_train_iters(runner, self.interval):
assert torch.isfinite(outputs['loss']),\
runner.logger.info('loss become infinite or NaN!')
将钩子的配置传给执行器的 custom_hooks 的参数,执行器初始化的时候会注册钩子,
from mmengine.runner import Runner
custom_hooks = [
dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...) # 实例化执行器,主要完成环境的初始化以及各种模块的构建
runner.train() # 执行器开始训练
改变钩子hook的优先级,自定义钩子的优先级默认为 NORMAL (50),如果想改变钩子的优先级,则可以在配置中设置 priority 字段
custom_hooks = [
dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]
也可以在定义类时给定优先级
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
priority = 'ABOVE_NORMAL'
自定义网络结构
(正在写)