详解MMdetectionHOOK机制

编辑 | 古月居

点击下方卡片,关注“自动驾驶之心”公众号

ADAS巨卷干货,即可获取

最近做了一段时间的目标检测,不得不说检测这块还是相对比较复杂的,在熟悉项目的同时也确实学习到了很多有用的东西。

MMdetetion是现在最著名、算法包最多并且使用人数最多的训练框架,其中的源码非常值得学习,今天总结下我对其中HOOK(钩子)机制的理解。

MMdetection最近更新很多,我以2.4.0版本的代码进行解读,分享自己的理解,也吸纳观众的点评。

HOOK、Runer的定义在MMCV当中,MMdetection和MMCV是版本匹配的,我这里使用的是MMCV 1.1.2的代码。(HOOK相关的定义主要在MMCV中,下面用的代码都是摘自于MMCV)。

1.HOOK机制的作用

MMdetection中的HOOK可以理解为一种触发器,也可以理解为一种训练框架的架构规范,它规定了在算法训练过程中的种种操作,并且我们可以通过继承HOOK类,然后注册HOOK自定义我们想要的操作。

首先看一下HOOK的基类定义

# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry


HOOKS = Registry('hook')




class Hook:


    def before_run(self, runner):
        pass


    def after_run(self, runner):
        pass


    def before_epoch(self, runner):
        pass


    def after_epoch(self, runner):
        pass


    def before_iter(self, runner):
        pass


    def after_iter(self, runner):
        pass


    def before_train_epoch(self, runner):
        self.before_epoch(runner)


    def before_val_epoch(self, runner):
        self.before_epoch(runner)


    def after_train_epoch(self, runner):
        self.after_epoch(runner)


    def after_val_epoch(self, runner):
        self.after_epoch(runner)


    def before_train_iter(self, runner):
        self.before_iter(runner)


    def before_val_iter(self, runner):
        self.before_iter(runner)


    def after_train_iter(self, runner):
        self.after_iter(runner)


    def after_val_iter(self, runner):
        self.after_iter(runner)


    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False


    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False


    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False


    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

 可以说基类函数中定义了许多我们在模型训练中需要用到的一些功能,如果想定义一些操作我们就可以继承这个类并定制化我们的功能,可以看到HOOK中每一个参数都是有runner作为参数传入的。

关于Runner的作用下一篇文章接着说,简而言之,Runner是一个模型训练的工厂,在其中我们可以加载数据、训练、验证以及梯度backward等等全套流程。

MMdetection在设计的时候也为runner传入丰富的参数,定义了一个非常好的训练范式。在你的每一个hook函数中,都可以对runner进行你想要的操作。

而HOOK是怎么嵌套进runner中的呢?其实是在Runner中定义了一个hook的list,list中的每一个元素就是一个实例化的HOOK对象。

其中提供了两种注册hook的方法,register_hook是传入一个实例化的HOOK对象,并将它插入到一个列表中,register_hook_from_cfg是传入一个配置项,根据配置项来实例化HOOK对象并插入到列表中。

当然第二种方法又是MMLab的开源生态中定义的一种基础方法mmcv.build_from_cfg了,无论在MMdetection还是其他MMLab开源的算法框架中,都遵循着MMCV的这套基于配置项实例化对象的方法。

毕竟MMCV是提供了一个基础的功能,服务于各个算法框架,这也是为什么MMLab的代码高质量的原因。不仅仅是算法的复现,更是架构、编程范式的一种体现,真·代码如诗。

def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.
        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.
        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        # hook是分优先级插入到list中的,在MMdetection中不同的HOOK是有优先级的,为什么呢?稍后在hook的调用中解释哈
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)


    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.
        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
              and 'priority' indicating its type and priority.
        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

调用HOOK函数

def call_hook(self, fn_name):
        """Call all hooks.
        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

可以看到HOOK是调用的时候是遍历List,然后根据HOOK的名字来调用。这也是为什么要区分优先级的原因,优先级越高的放在List的前面,这样就能更快地被调用。

当你想用before_run_epoch来做A和B两件事情的时候,在runner里面就是调用一次self.before_run_epoch,但是先做A还是先做B,就是通过不同的HOOK的优先级来决定了。

比如在evaluation的时候对需要做测试,但是测试前对参数做滑动平均。比如emaHOOK中的72行,也写明了要在测试之前做指数滑动平均。

def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

checkpoint.py的HOOK中,同样也定义了after_train_epoch函数如下:

@master_only
    def after_train_epoch(self, runner):
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return


        runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
        if not self.out_dir:
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)


        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
            current_epoch = runner.epoch + 1
            for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
                ckpt_path = os.path.join(self.out_dir,
                                         filename_tmpl.format(epoch))
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                else:
                    break

从测试代码中可以看到不同的HOOK虽然都是重写了after_train_epoch函数,但是调用的顺序还是先调用ema.py中的,然后再调用checkpoint.py中的after_train_epoch。

resume_ema_hook = EMAHook(
        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner()
    runner.model = demo_model
    # 设置了HIGHREST的优先级
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)

具体的优先级定义有以下7种,作为HOOK的类成员属性。具体定义在链接中。

https://link.zhihu.com/?target=https%3A//github.com/open-mmlab/mmcv/blob/eb65c21da219e79d2fbc27dd056e94991b8718a8/mmcv/runner/priority.py

+------------+------------+
    | Level      | Value      |
    +============+============+
    | HIGHEST    | 0          |
    +------------+------------+
    | VERY_HIGH  | 10         |
    +------------+------------+
    | HIGH       | 30         |
    +------------+------------+
    | NORMAL     | 50         |
    +------------+------------+
    | LOW        | 70         |
    +------------+------------+
    | VERY_LOW   | 90         |
    +------------+------------+
    | LOWEST     | 100        |
    +------------+------------+

2.举一个简单的例子

最近打算好好锻炼身体,健康生活,努力工作,我打算让自己变得更加自律。

我给自己定下了几个条例,每天吃早饭之前得晨练30分钟,运动完之后才会感觉充满活力。

每天吃午饭之前我得跑上一个实验,吃完饭之后回来刚好可以看下中间结果,吃完午饭之后我感觉结果没问题我需要午休30分钟, 晚上下班前我如果没什么事再锻炼30分钟。

秉承着这样的原则我给自己定义一个HOOK来规范我的生活。

定义我的HOOK

import sys
class HOOK:


    def before_breakfirst(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))


    def after_breakfirst(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))


    def before_lunch(self, runner):
        print('{}:吃午饭之前跑上实验'.format(sys._getframe().f_code.co_name))


    def after_lunch(self, runner):
        print('{}:吃完午饭午休30分钟'.format(sys._getframe().f_code.co_name))


    def before_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))


    def after_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))


    def after_finish_work(self, runner, are_you_busy=False):
        if are_you_busy:
            print('{}:今天事贼多,还是加班吧'.format(sys._getframe().f_code.co_name))
        else:
            print('{}:今天没啥事,去锻炼30分钟'.format(sys._getframe().f_code.co_name))

定义我的Runner

class Runner(object):
    def __init__(self, ):
        pass
        self._hooks = []


    def register_hook(self, hook):
        # 这里不做优先级判断,直接在头部插入HOOK
        self._hooks.insert(0, hook)


    def call_hook(self, hook_name):
        for hook in self._hooks:
            getattr(hook, hook_name)(self)


    def run(self):
        print('开始启动我的一天')
        self.call_hook('before_breakfirst')
        self.call_hook('after_breakfirst')
        self.call_hook('before_lunch')
        self.call_hook('after_lunch')
        self.call_hook('before_dinner')
        self.call_hook('after_dinner')
        self.call_hook('after_finish_work')
        print('~~睡觉~~')

运行main函数,注册HOOK并且调用Runner.run()开启我的一天

from MyHook import HOOK
from MyRunner import Runner
runner = Runner()
hook = HOOK()
runner.register_hook(hook)
runner.run()

得到的输出结果如下:

开始启动我的一天
before_breakfirst:吃早饭之前晨练30分钟
after_breakfirst:吃早饭之前晨练30分钟
before_lunch:吃午饭之前跑上实验
after_lunch:吃完午饭午休30分钟
before_dinner: 没想好做什么
after_dinner: 没想好做什么
after_finish_work:今天没啥事,去锻炼30分钟
~~睡觉~~

3.总结

MMdetection中的HOOK设计巧妙,很好地对算法训练、测试进行了抽象和解耦。

每一个做上层算法模型的,都值得一看。感谢MMLab贡献这么优质的代码,让我等凡夫俗子醍醐灌顶。

除了HOOK之外,这个代码中还有很多优质的思想。比如Runner是怎么做到包办一切的?注册器这个中枢管理系统是怎么工作的?多卡训练的一些坑是怎么解决的?等等等等,我也在持续地学习和消化。路漫漫其修远兮,吾将上下而求索。

一个小题目:我的代码中每个函数输出的时候都会打印出这个函数名,这个可以用装饰器很方便地解决奥。装饰器这个东西在MMLab的系列项目中有大量的应用。

其中对fp16的支持让大家赞不绝口。接下来有时间,对Runner、Register、装饰器这些东西好好盘一盘。

国内首个自动驾驶学习社区

近1000人的交流社区,和20+自动驾驶技术栈学习路线,想要了解更多自动驾驶感知(分类、检测、分割、关键点、车道线、3D目标检测、多传感器融合、目标跟踪、光流估计、轨迹预测)、自动驾驶定位建图(SLAM、高精地图)、自动驾驶规划控制、领域技术方案、AI模型部署落地实战、行业动态、岗位发布,欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频,期待交流!

d135576f86f9391b760ceb8544bfac34.jpeg

自动驾驶之心】全栈技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦目标检测、语义分割、全景分割、实例分割、关键点检测、车道线、目标跟踪、3D目标检测、BEV感知、多传感器融合、SLAM、光流估计、深度估计、轨迹预测、高精地图、NeRF、规划控制、模型部署落地、自动驾驶仿真测试、产品经理、硬件配置、AI求职交流等方向;

a478f42ccaec82f5c6e6ec9b04df2c8d.jpeg

添加汽车人助理微信邀请入群

备注:学校/公司+方向+昵称

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值