一、HOOKS实例化位置
新手往往找不到注册器(Registry)实例HOOKS是在哪里被创建的。下面按照import的调用链,列出HOOKS被实例化的完整过程:
# train_KittiTiny.py
from mmdet.datasets.builder import DATASETS
# /mmdet/datasets/builder.py
from mmcv.runner import get_dist_info
# /mmcv/runner/__init__.py
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, Hook, IterTimerHook,
LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
OptimizerHook, PaviLoggerHook, SyncBuffersHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
# /mmcv/runner/hooks/hook.py
HOOKS = Registry('hook')
二、HOOKS注册位置
实例化注册表HOOKS之后,再通过Python装饰器@HOOKS.register_module(),将类名和类对象添加到HOOKS._module_dict中,完成注册表的注册。
那么,注册表HOOKS是在什么时候完成注册的呢?答案是:在导入hooks包、依次执行其__init__.py中的import语句时:
from .hook import HOOKS, Hook -> 实例化注册表HOOKS
from *** import *** -> 导入各个hook模块时执行装饰器,完成注册表HOOKS的注册
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
ExpLrUpdaterHook, FixedLrUpdaterHook,
FlatCosineAnnealingLrUpdaterHook, InvLrUpdaterHook,
LrUpdaterHook, OneCycleLrUpdaterHook,
PolyLrUpdaterHook, StepLrUpdaterHook)
from .memory import EmptyCacheHook
from .momentum_updater import (CosineAnnealingMomentumUpdaterHook,
CyclicMomentumUpdaterHook, MomentumUpdaterHook,
OneCycleMomentumUpdaterHook,
StepMomentumUpdaterHook)
from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, OptimizerHook)
from .profiler import ProfilerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook
# Public names of the hooks package, one per line. The order is the same as
# in the original listing; the group comments are only a reading aid.
__all__ = [
    # registry and base class
    'HOOKS',
    'Hook',
    # checkpoint / closure
    'CheckpointHook',
    'ClosureHook',
    # learning-rate updaters
    'LrUpdaterHook',
    'FixedLrUpdaterHook',
    'StepLrUpdaterHook',
    'ExpLrUpdaterHook',
    'PolyLrUpdaterHook',
    'InvLrUpdaterHook',
    'CosineAnnealingLrUpdaterHook',
    'FlatCosineAnnealingLrUpdaterHook',
    'CosineRestartLrUpdaterHook',
    'CyclicLrUpdaterHook',
    'OneCycleLrUpdaterHook',
    # optimizer
    'OptimizerHook',
    'Fp16OptimizerHook',
    # timer, sampler seed, memory
    'IterTimerHook',
    'DistSamplerSeedHook',
    'EmptyCacheHook',
    # loggers
    'LoggerHook',
    'MlflowLoggerHook',
    'PaviLoggerHook',
    'TextLoggerHook',
    'TensorboardLoggerHook',
    'NeptuneLoggerHook',
    'WandbLoggerHook',
    'DvcliveLoggerHook',
    # momentum updaters
    'MomentumUpdaterHook',
    'StepMomentumUpdaterHook',
    'CosineAnnealingMomentumUpdaterHook',
    'CyclicMomentumUpdaterHook',
    'OneCycleMomentumUpdaterHook',
    # buffers, EMA, evaluation, profiler
    'SyncBuffersHook',
    'EMAHook',
    'EvalHook',
    'DistEvalHook',
    'ProfilerHook',
    # gradient-cumulative optimizers
    'GradientCumulativeOptimizerHook',
    'GradientCumulativeFp16OptimizerHook',
]
三、Hook和它的子类
基类Hook中定义了模型训练各个阶段会被调用的方法(hook函数)。如果想在某个阶段加入自定义操作,只需继承Hook类并重写对应的方法即可。Hook中的每一个方法都以runner作为参数传入,因此在每一个hook函数中,都可以对runner进行必要的操作。
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, is_method_overridden
HOOKS = Registry('hook')
class Hook:
    """Base class whose methods are the callbacks invoked with a runner.

    Every callback takes the ``runner`` as its only argument and does
    nothing by default. The train/val specific callbacks forward to the
    generic ``before_epoch`` / ``after_epoch`` / ``before_iter`` /
    ``after_iter`` callbacks, so a subclass may override either the
    generic method or a stage-specific one.
    """

    # Stage names; ``get_triggered_stages`` reports its result in this order.
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        """Callback for the ``before_run`` stage; no-op by default."""
        pass

    def after_run(self, runner):
        """Callback for the ``after_run`` stage; no-op by default."""
        pass

    def before_epoch(self, runner):
        """Generic callback shared by the train and val ``before_*_epoch``."""
        pass

    def after_epoch(self, runner):
        """Generic callback shared by the train and val ``after_*_epoch``."""
        pass

    def before_iter(self, runner):
        """Generic callback shared by the train and val ``before_*_iter``."""
        pass

    def after_iter(self, runner):
        """Generic callback shared by the train and val ``after_*_iter``."""
        pass

    def before_train_epoch(self, runner):
        """Forward to ``before_epoch``."""
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        """Forward to ``before_epoch``."""
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        """Forward to ``after_epoch``."""
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        """Forward to ``after_epoch``."""
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        """Forward to ``before_iter``."""
        self.before_iter(runner)

    def before_val_iter(self, runner):
        """Forward to ``before_iter``."""
        self.before_iter(runner)

    def after_train_iter(self, runner):
        """Forward to ``after_iter``."""
        self.after_iter(runner)

    def after_val_iter(self, runner):
        """Forward to ``after_iter``."""
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        """Return True when ``runner.epoch + 1`` is a multiple of ``n``.

        Always returns False when ``n`` is not positive.
        """
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        """Return True when ``runner.inner_iter + 1`` is a multiple of ``n``.

        Always returns False when ``n`` is not positive.
        """
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        """Return True when ``runner.iter + 1`` is a multiple of ``n``.

        Always returns False when ``n`` is not positive.
        """
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        """Return True when ``runner.inner_iter + 1`` equals the length of
        ``runner.data_loader``."""
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        """Return True when ``runner.epoch + 1`` equals
        ``runner._max_epochs``."""
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        """Return True when ``runner.iter + 1`` equals
        ``runner._max_iters``."""
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        """Return the stages for which this hook overrides a callback.

        A stage is included when ``is_method_overridden`` reports that this
        instance overrides either the stage method itself or the generic
        method that the stage forwards to.

        Returns:
            list[str]: The triggered stage names, in ``Hook.stages`` order.
        """
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)

        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        # Filter Hook.stages (rather than returning the set) to keep a
        # deterministic order.
        return [stage for stage in Hook.stages if stage in trigger_stages]
Hook的子类
hooks/__init__.py
以下代码位于'/mmcv/runner/hooks/…'目录中
#hook.py
HOOKS = Registry('hook')
class Hook:
stages = ('before_run', 'before_train_epoch', 'before_train_iter',
'after_train_iter', 'after_train_epoch', 'before_val_epoch',
'before_val_iter', 'after_val_iter', 'after_val_epoch',
'after_run')
#checkpoint.py
@HOOKS.register_module()
class CheckpointHook(Hook):
#closure.py
@HOOKS.register_module()
class ClosureHook(Hook):
#ema.py
@HOOKS.register_module()
class EMAHook(Hook):
#dvclive.py
@HOOKS.register_module()
class DvcliveLoggerHook(LoggerHook):
#mlflow.py
@HOOKS.register_module()
class MlflowLoggerHook(LoggerHook):
#neptune.py
@HOOKS.register_module()
class NeptuneLoggerHook(LoggerHook):
#pavi.py
@HOOKS.register_module()
class PaviLoggerHook(LoggerHook):
#tensorboard.py
@HOOKS.register_module()
class TensorboardLoggerHook(LoggerHook):
#text.py
@HOOKS.register_module()
class TextLoggerHook(LoggerHook):
#wandb.py
@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
#iter_timer.py
@HOOKS.register_module()
class IterTimerHook(Hook):
#lr_updater.py
# NOTE(review): class bodies are elided in this excerpt and the decorator is
# shown only once; presumably each subclass below carries its own
# @HOOKS.register_module() in the real file -- confirm against the source.
@HOOKS.register_module()
class FixedLrUpdaterHook(LrUpdaterHook):
class StepLrUpdaterHook(LrUpdaterHook):
class ExpLrUpdaterHook(LrUpdaterHook):
class PolyLrUpdaterHook(LrUpdaterHook):
class InvLrUpdaterHook(LrUpdaterHook):
class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
class CosineRestartLrUpdaterHook(LrUpdaterHook):
class CyclicLrUpdaterHook(LrUpdaterHook):
class OneCycleLrUpdaterHook(LrUpdaterHook):
#memory.py
@HOOKS.register_module()
class EmptyCacheHook(Hook):
#momentum_updater.py
@HOOKS.register_module()
class StepMomentumUpdaterHook(MomentumUpdaterHook):
class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
#optimizer.py
@HOOKS.register_module()
class OptimizerHook(Hook):
class GradientCumulativeOptimizerHook(OptimizerHook):
class Fp16OptimizerHook(OptimizerHook):
class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
Fp16OptimizerHook):
#profiler.py
@HOOKS.register_module()
class ProfilerHook(Hook):
#sampler_seed.py
@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
#sync_buffer.py
@HOOKS.register_module()
class SyncBuffersHook(Hook):