mmdetecton从精通到高阶系列,最全最详细

钩子Hook

参照MMDetection官网
mmdetection的钩子分为内置钩子与自定义钩子,内置钩子默认注册,自定义钩子需要自己注册。

优先级

每个钩子都有对应的优先级,在同一位点,钩子的优先级越高,越早被执行器调用,如果优先级一样,被调用的顺序和钩子注册的顺序一致。
优先级列表如下:

HIGHEST (0)

VERY_HIGH (10)

HIGH (30)

ABOVE_NORMAL (40)

NORMAL (50)

BELOW_NORMAL (60)

LOW (70)

VERY_LOW (90)

LOWEST (100)

内置钩子

1.CheckpointHook
自定义权重保存的间隔,如果是分布式多卡训练,则只有主(master)进程会保存权重,支持按 epoch 数或者 iteration 数保存权重

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True)) 
# by_epoch 的默认值为 True,以 epoch作为保存间隔
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))  
# 以迭代次数 iteration作为保存间隔

其他参数:

max_keep_ckpts=10  # 只保存最新10个权重 
save_best='auto'  # 只保存最优权重,'auto'根据验证集的第一个评价指标判断权重最优
save_best='accuracy', rule='greater'
out_dir='/path/of/directory'  # 指定保存权重的路径

2.LoggerHook
LoggerHook负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端,每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数如下:

default_hooks = dict(logger=dict(type='LoggerHook', interval=20))

3.ParamSchedulerHook
ParamSchedulerHook 遍历执行器的所有优化器参数调整策略(Parameter Scheduler)并逐个调用 step 方法更新优化器的参数。ParamSchedulerHook默认注册到执行器并且没有可配置的参数,无需对其做任何配置。

4.IterTimerHook
IterTimerHook 用于记录加载数据的时间以及迭代一次耗费的时间。IterTimerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

5.DistSamplerSeedHook
DistSamplerSeedHook 在分布式训练时调用 Sampler 的 step 方法以确保 shuffle 参数生效。DistSamplerSeedHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

6.RuntimeInfoHook
RuntimeInfoHook 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,以便其他无法访问执行器的模块能够获取到这些信息。RuntimeInfoHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

7.EMAHook
EMAHook 在训练过程中对模型执行指数滑动平均操作,目的是提高模型的鲁棒性。注意:指数滑动平均生成的模型只用于验证和测试,不影响训练。

custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

EMAHook 默认使用 ExponentialMovingAverage,可选值还有 StochasticWeightAverage 和MomentumAnnealingEMA。可以通过设置 ema_type 使用其他的平均策略。

custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]

8.EmptyCacheHook
EmptyCacheHook 调用 torch.cuda.empty_cache() 释放未被使用的显存。可以通过设置 before_epoch, after_iter 以及 after_epoch 参数控制释显存的时机,

# 每一个 epoch 结束都会执行释放操作
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
before_epoch  # (bool) Defaults to False
after_epoch  # (bool) Defaults to True
after_iter  # (bool) Defaults to False

9.SyncBuffersHook
SyncBuffersHook 在分布式训练每一轮(epoch)结束时同步模型的 buffer,例如 BN 层的 running_mean 以及 running_var。

custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

自定义钩子

通过继承基类并重写相应位点方法,所有的hook都储存在self._hook列表中,其顺序和优先级挂钩
例:在每次迭代后判断损失值是否无穷大,重写 after_train_iter 位点

import torch
from mmengine.registry import HOOKS
from mmengine.hooks import Hook


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    def __init__(self, interval=50):
        self.interval = interval  # 钩子执行的间隔,每iterations执行一次钩子
	
    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """
        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']),\
                runner.logger.info('loss become infinite or NaN!')

将钩子的配置传给执行器的 custom_hooks 的参数,执行器初始化的时候会注册钩子,

from mmengine.runner import Runner

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...)  # 实例化执行器,主要完成环境的初始化以及各种模块的构建
runner.train()  # 执行器开始训练

改变钩子hook的优先级,自定义钩子的优先级默认为 NORMAL (50),如果想改变钩子的优先级,则可以在配置中设置 priority 字段

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]

也可以在定义类时给定优先级

@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    priority = 'ABOVE_NORMAL'

自定义网络结构

(正在写)

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值