MMOCR的hook定义在
/home/xhhao/anaconda3/envs/open-mmocr/lib/python3.7/site-packages/mmcv/runner/hooks/hook.py
这只是hook基类
具体用的是哪个hook是在mmocr/apis/train.py
这是train的时候用的DistSamplerSeedHook
这是val的时候用的DistEvalHook
Hook机制规定了在算法训练过程中的种种操作,并且我们可以通过继承HOOK类,然后注册HOOK自定义我们想要的操作。
HOOK基类的定义
from mmcv.utils import Registry HOOKS = Registry('hook') class Hook: def before_run(self, runner): pass def after_run(self, runner): pass def before_epoch(self, runner): pass def after_epoch(self, runner): pass def before_iter(self, runner): pass def after_iter(self, runner): pass def before_train_epoch(self, runner): self.before_epoch(runner) def before_val_epoch(self, runner): self.before_epoch(runner) def after_train_epoch(self, runner): self.after_epoch(runner) def after_val_epoch(self, runner): self.after_epoch(runner) def before_train_iter(self, runner): self.before_iter(runner) def before_val_iter(self, runner): self.before_iter(runner) def after_train_iter(self, runner): self.after_iter(runner) def after_val_iter(self, runner): self.after_iter(runner) def every_n_epochs(self, runner, n): return (runner.epoch + 1) % n == 0 if n > 0 else False def every_n_inner_iters(self, runner, n): return (runner.inner_iter + 1) % n == 0 if n > 0 else False def every_n_iters(self, runner, n): return (runner.iter + 1) % n == 0 if n > 0 else False def end_of_epoch(self, runner): return runner.inner_iter + 1 == len(runner.data_loader)
在baserunner类中有register_hook函数,还有很多地方有
hook函数是有多种类型的
hook优先级
import sys class HOOK: def before_breakfast(self, runner): print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name)) def after_breakfast(self, runner): print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name)) def before_lunch(self, runner): print('{}:吃午饭之前跑上实验'.format(sys._getframe().f_code.co_name)) def after_lunch(self, runner): print('{}:吃完午饭午休30分钟'.format(sys._getframe().f_code.co_name)) def before_dinner(self, runner): print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name)) def after_dinner(self, runner): print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name)) def after_finish_work(self, runner, are_you_busy=False): if are_you_busy: print('{}:今天事贼多,还是加班吧'.format(sys._getframe().f_code.co_name)) else: print('{}:今天没啥事,去锻炼30分钟'.format(sys._getframe().f_code.co_name)) class Runner(object): def __init__(self, ): pass self._hooks = [] def register_hook(self, hook): # 这里不做优先级判断,直接在头部插入HOOK self._hooks.insert(0, hook) #将hook这个类插入到self._hook list中的第0个位置 def call_hook(self, hook_name): for hook in self._hooks: getattr(hook, hook_name)(self) def run(self): print('开始启动我的一天') self.call_hook('before_breakfast') self.call_hook('after_breakfast') self.call_hook('before_lunch') self.call_hook('after_lunch') self.call_hook('before_dinner') self.call_hook('after_dinner') self.call_hook('after_finish_work') print('~~睡觉~~') runner = Runner() hook = HOOK() runner.register_hook(hook) runner.run()
runner中用到哪个hook要在main函数中给它注册