最近做了一段时间的目标检测,不得不说检测这块还是相对比较复杂的,在熟悉项目的同时也确实学习到了很多有用的东西。MMdetetion是现在最著名、算法包最多并且使用人数最多的训练框架,其中的源码非常值得学习,今天总结下我对其中HOOK(钩子)机制的理解。
MMdetection最近更新很多,我以2.4.0版本的代码进行解读,分享自己的理解,也吸纳观众的点评。HOOK、Runer的定义在MMCV当中,MMdetection和MMCV是版本匹配的,我这里使用的是MMCV 1.1.2的代码。(HOOK相关的定义主要在MMCV中,下面用的代码都是摘自于MMCV)。
1.HOOK机制的作用
MMdetection中的HOOK可以理解为一种触发器,也可以理解为一种训练框架的架构规范,它规定了在算法训练过程中的种种操作,并且我们可以通过继承HOOK类,然后注册HOOK自定义我们想要的操作。
首先看一下HOOK的基类定义
# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry
HOOKS = Registry('hook')
class Hook:
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
def before_train_epoch(self, runner):
self.before_epoch(runner)
def before_val_epoch(self, runner):
self.before_epoch(runner)
def after_train_epoch(self, runner):
self.after_epoch(runner)
def after_val_epoch(self, runner):
self.after_epoch(runner)
def before_train_iter(self, runner):
self.before_iter(runner)
def before_val_iter(self, runner):
self.before_iter(runner)
def after_train_iter(self, runner):
self.after_iter(runner)
def after_val_iter(self, runner):
self.after_iter(runner)
def every_n_epochs(self, runner, n):
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n):
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader)
可以说基类函数中定义了许多我们在模型训练中需要用到的一些功能,如果想定义一些操作我们就可以继承这个类并定制化我们的功能,可以看到HOOK中每一个参数都是有runner作为参数传入的。关于