(十三)mmdetection源码解读:MMCV 核心组件 HOOKS(二)

一、HOOKS实例化位置

对于新手来说,常常找不到注册器的实例化位置。下面列出了通过 import 操作逐级导入、最终实例化 HOOKS 的完整过程:

# train_KittiTiny.py
from mmdet.datasets.builder import DATASETS
# /mmdet/datasets/builder.py
from mmcv.runner import get_dist_info
# /mmcv/runner/__init__.py
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
                    DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
                    Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
                    GradientCumulativeOptimizerHook, Hook, IterTimerHook,
                    LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
                    OptimizerHook, PaviLoggerHook, SyncBuffersHook,
                    TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
# /mmcv/runner/hooks/__init__.py  (from .hook import HOOKS, Hook)
# /mmcv/runner/hooks/hook.py
# Global hook registry, created as a Registry named 'hook'. Hook classes are
# added to it later by the @HOOKS.register_module() decorator.
HOOKS = Registry('hook')

二、HOOKS注册位置

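实例化注册表 HOOKS 之后,再通过 Python 装饰器 @HOOKS.register_module() 将类名和类添加到 HOOKS._module_dict 中,完成注册表的注册。
那么注册表 HOOKS 是在什么时候完成注册的呢?答案是在导入 /mmcv/runner/hooks/__init__.py(代码如下)的时候:
from .hook import HOOKS, Hook -> 实例化注册表 HOOKS
from .xxx import XxxHook -> 执行各子模块中被 @HOOKS.register_module() 装饰的类定义,完成注册表 HOOKS 的注册
注意:Python 中每个模块只会执行一次,所以 HOOKS = Registry('hook') 只会执行一次,所有子模块注册到的都是同一个 HOOKS 对象。另外,虽然 __init__.py 中 from .checkpoint 等语句排在 from .hook 之前,但 checkpoint.py 自身就用到了 HOOKS 和 Hook,因此在导入第一个子模块时 hook.py 就已经被执行,HOOKS 也就已经实例化了。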

from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
                     NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
from .lr_updater import (CosineAnnealingLrUpdaterHook,
                         CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
                         ExpLrUpdaterHook, FixedLrUpdaterHook,
                         FlatCosineAnnealingLrUpdaterHook, InvLrUpdaterHook,
                         LrUpdaterHook, OneCycleLrUpdaterHook,
                         PolyLrUpdaterHook, StepLrUpdaterHook)
from .memory import EmptyCacheHook
from .momentum_updater import (CosineAnnealingMomentumUpdaterHook,
                               CyclicMomentumUpdaterHook, MomentumUpdaterHook,
                               OneCycleMomentumUpdaterHook,
                               StepMomentumUpdaterHook)
from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
                        GradientCumulativeOptimizerHook, OptimizerHook)
from .profiler import ProfilerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook

# Public names re-exported by the mmcv.runner.hooks package; this is exactly
# what `from mmcv.runner.hooks import *` exposes.
__all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
    'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
    'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
    'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'OptimizerHook',
    'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
    'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
    'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook',
    'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook',
    'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
    'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
    'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook',
    'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook'
]

三、Hook 和它的子类

基类 Hook 中定义了许多模型训练中会用到的功能:各个阶段(stage)的回调函数,以及 every_n_epochs 等辅助判断函数。如果想自定义某些操作,可以继承这个类并重写相应的函数来定制功能。Hook 中的每一个函数都以 runner 作为参数传入,因此在每一个 hook 函数中都可以对 runner 进行必要的操作。

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, is_method_overridden

# Global hook registry, created as a Registry named 'hook'. Subclasses of Hook
# are added to it by decorating them with @HOOKS.register_module().
HOOKS = Registry('hook')


class Hook:
    """Base class for runner hooks.

    Every stage method takes the runner as its only argument and does
    nothing by default; subclasses override the stages they care about.
    The train/val epoch and iteration stages forward to the generic
    ``before_epoch`` / ``after_epoch`` / ``before_iter`` / ``after_iter``
    methods, so overriding one generic method covers both phases.
    """

    # Every stage name a hook can be triggered at; get_triggered_stages()
    # reports its result in this same order.
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    # --- overridable callbacks (no-ops in the base class) -----------------

    def before_run(self, runner):
        """Callback for the ``before_run`` stage; no-op here."""

    def after_run(self, runner):
        """Callback for the ``after_run`` stage; no-op here."""

    def before_epoch(self, runner):
        """Epoch-start callback shared by train and val; no-op here."""

    def after_epoch(self, runner):
        """Epoch-end callback shared by train and val; no-op here."""

    def before_iter(self, runner):
        """Iteration-start callback shared by train and val; no-op here."""

    def after_iter(self, runner):
        """Iteration-end callback shared by train and val; no-op here."""

    # --- phase-specific stages, forwarded to the generic callbacks --------
    # The forwarders deliberately return None rather than the callee's value.

    def before_train_epoch(self, runner):
        """Forward the train epoch start to :meth:`before_epoch`."""
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        """Forward the val epoch start to :meth:`before_epoch`."""
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        """Forward the train epoch end to :meth:`after_epoch`."""
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        """Forward the val epoch end to :meth:`after_epoch`."""
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        """Forward the train iteration start to :meth:`before_iter`."""
        self.before_iter(runner)

    def before_val_iter(self, runner):
        """Forward the val iteration start to :meth:`before_iter`."""
        self.before_iter(runner)

    def after_train_iter(self, runner):
        """Forward the train iteration end to :meth:`after_iter`."""
        self.after_iter(runner)

    def after_val_iter(self, runner):
        """Forward the val iteration end to :meth:`after_iter`."""
        self.after_iter(runner)

    # --- scheduling predicates --------------------------------------------

    def every_n_epochs(self, runner, n):
        """Return whether ``runner.epoch + 1`` is a multiple of ``n``.

        A non-positive ``n`` disables the check and yields False.
        """
        if n > 0:
            return (runner.epoch + 1) % n == 0
        return False

    def every_n_inner_iters(self, runner, n):
        """Return whether ``runner.inner_iter + 1`` is a multiple of ``n``.

        A non-positive ``n`` disables the check and yields False.
        """
        if n > 0:
            return (runner.inner_iter + 1) % n == 0
        return False

    def every_n_iters(self, runner, n):
        """Return whether ``runner.iter + 1`` is a multiple of ``n``.

        A non-positive ``n`` disables the check and yields False.
        """
        if n > 0:
            return (runner.iter + 1) % n == 0
        return False

    def end_of_epoch(self, runner):
        """Return whether ``runner.inner_iter + 1`` equals ``len(runner.data_loader)``."""
        completed = runner.inner_iter + 1
        return completed == len(runner.data_loader)

    def is_last_epoch(self, runner):
        """Return whether ``runner.epoch + 1`` equals ``runner._max_epochs``."""
        completed = runner.epoch + 1
        return completed == runner._max_epochs

    def is_last_iter(self, runner):
        """Return whether ``runner.iter + 1`` equals ``runner._max_iters``."""
        completed = runner.iter + 1
        return completed == runner._max_iters

    # --- introspection ----------------------------------------------------

    def get_triggered_stages(self):
        """List the stages this hook actually acts on.

        A stage counts if its own method is overridden relative to ``Hook``
        (as reported by ``is_method_overridden``), or if the generic
        callback it forwards to is overridden. The result follows the order
        of ``Hook.stages``.
        """
        # Generic callbacks fan out to both their train and val stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        triggered = {
            stage for stage in Hook.stages
            if is_method_overridden(stage, Hook, self)
        }
        for generic, concrete in method_stages_map.items():
            if is_method_overridden(generic, Hook, self):
                triggered.update(concrete)

        return [stage for stage in Hook.stages if stage in triggered]

Hook 的子类
hooks/__init__.py
以下代码位于 '/mmcv/runner/hooks/' 目录下的各个文件中(这里只列出了类的声明,省略了类体):

# Outline of the hook classes defined under /mmcv/runner/hooks/.
# Only the class headers are listed and the bodies are omitted, so this
# listing is a table of contents and is not runnable as-is.
#hook.py
HOOKS = Registry('hook')
class Hook:
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')
#checkpoint.py
@HOOKS.register_module()
class CheckpointHook(Hook):
#closure.py
@HOOKS.register_module()
class ClosureHook(Hook):
#ema.py
@HOOKS.register_module()
class EMAHook(Hook):
# NOTE(review): the logger hooks below are imported via `from .logger import ...`
# in hooks/__init__.py, so these files presumably live in a logger/ subpackage —
# confirm against the source tree.
#dvclive.py
@HOOKS.register_module()
class DvcliveLoggerHook(LoggerHook):
#mlflow.py
@HOOKS.register_module()
class MlflowLoggerHook(LoggerHook):
#neptune.py
@HOOKS.register_module()
class NeptuneLoggerHook(LoggerHook):
#pavi.py
@HOOKS.register_module()
class PaviLoggerHook(LoggerHook):
#tensorboard.py
@HOOKS.register_module()
class TensorboardLoggerHook(LoggerHook):
#text.py
@HOOKS.register_module()
class TextLoggerHook(LoggerHook):
#wandb.py
@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
#iter_timer.py
@HOOKS.register_module()
class IterTimerHook(Hook):
#lr_updater.py
# NOTE(review): this outline shows the decorator only once per file; presumably
# each class below is registered with @HOOKS.register_module() too — confirm.
@HOOKS.register_module()
class FixedLrUpdaterHook(LrUpdaterHook):
class StepLrUpdaterHook(LrUpdaterHook):
class ExpLrUpdaterHook(LrUpdaterHook):
class PolyLrUpdaterHook(LrUpdaterHook):
class InvLrUpdaterHook(LrUpdaterHook):
class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
class CosineRestartLrUpdaterHook(LrUpdaterHook):
class CyclicLrUpdaterHook(LrUpdaterHook):
class OneCycleLrUpdaterHook(LrUpdaterHook):
#memory.py
@HOOKS.register_module()
class EmptyCacheHook(Hook):
#momentum_updater.py

@HOOKS.register_module()
class StepMomentumUpdaterHook(MomentumUpdaterHook):
class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
#optimizer.py
@HOOKS.register_module()
class OptimizerHook(Hook):
class GradientCumulativeOptimizerHook(OptimizerHook):
# NOTE(review): the extra indentation of the next two classes suggests they are
# nested inside an enclosing block of optimizer.py that this outline omits —
# confirm against the source.
    class Fp16OptimizerHook(OptimizerHook):
    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
                                              Fp16OptimizerHook):
#profiler.py
@HOOKS.register_module()
class ProfilerHook(Hook):
#sampler_seed.py
@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
#sync_buffer.py
@HOOKS.register_module()
class SyncBuffersHook(Hook):
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值