(五)mmdetection源码解读:何时注册HOOKS、MODELS、DATASETS、PIPELINES

我们在阅读mmdetection源代码的时候发现,很多文件路径下包含__init__.py文件

                                     

我们通常导入包的时候一般都是import xxx.xxx,或者from xxx.xxx import xxx,如果想批量导入,一般使用__init__.py文件。在__init__.py文件中,有一个很重要的变量__all__,只要我们配置了 __all__,就可以在其他模块中通过from 文件夹名称 import * 将配置在__all__列表中的所有模块一次性导入进来。

1、注册HOOKS

下面是hook/__init__.py源代码。

# Copyright (c) OpenMMLab. All rights reserved.
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
                     NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
from .lr_updater import (CosineAnnealingLrUpdaterHook,
                         CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
                         ExpLrUpdaterHook, FixedLrUpdaterHook,
                         FlatCosineAnnealingLrUpdaterHook, InvLrUpdaterHook,
                         LrUpdaterHook, OneCycleLrUpdaterHook,
                         PolyLrUpdaterHook, StepLrUpdaterHook)
from .memory import EmptyCacheHook
from .momentum_updater import (CosineAnnealingMomentumUpdaterHook,
                               CyclicMomentumUpdaterHook, MomentumUpdaterHook,
                               OneCycleMomentumUpdaterHook,
                               StepMomentumUpdaterHook)
from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
                        GradientCumulativeOptimizerHook, OptimizerHook)
from .profiler import ProfilerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook

__all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
    'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
    'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
    'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'OptimizerHook',
    'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
    'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
    'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook',
    'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook',
    'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
    'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
    'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook',
    'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook'
]

那么这个代码在什么时候调用呢,我们debug一下mmdetection的源代码,具体的执行过程如下: 

#自己编写的用来训练的.py文件
from mmdet.datasets.builder import DATASETS
#mmdet/datasets/builder.py
#这个地方就需要注意了,mmcv.runner是个文件夹,这里被当作包来引用,也就是会执行mmcv/runner/__init__.py
from mmcv.runner import get_dist_info
#mmcv/runner/__init__.py
from .base_runner import BaseRunner
#mmcv/runner/base_runner.py
#这个地方也需要注意了,.hooks是个文件夹,这里被当作包来引用,也就是会执行
mmcv/runner/hooks/__init__.py,上面的代码。
from .hooks import HOOKS, Hook
#mmcv/runner/hooks/__init__.py代码中的每个引用都会执行内部的装饰器函数,也就是把对应的类名和类添加到HOOKS._module_dict中。
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
#...

注意:.checkpoint,同一个目录引用

2、注册MODELS

from mmdet.datasets.builder import DATASETS

from mmdet.datasets import DATASETS

这两个有区别吗?

我的理解:第一个是从模块引入,第二个是从包引入,也就是第二个会执行mmdet/datasets/__init__.py,但实际上第一种写法也执行了mmdet/datasets/__init__.py,和的理解有出入。是不是可以理解为mmdet.datasets.builder是mmdet.datasets包里的模块?

#自己编写的用来训练的.py文件
from mmdet.datasets.builder import DATASETS
#mmdet/datasets/__init__.py
from .utils import (NumClassCheckHook, get_loading_pipeline,
                    replace_ImageToTensor)

#mmdet/datasets/utils.py
from mmdet.models.dense_heads import GARPNHead, RPNHead

#mmdet/models/__init__.py

from .backbones import *  # noqa: F401,F403
from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                      ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
                      build_detector, build_head, build_loss, build_neck,
                      build_roi_extractor, build_shared_head)
from .dense_heads import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .plugins import *  # noqa: F401,F403
from .roi_heads import *  # noqa: F401,F403
from .seg_heads import *  # noqa: F401,F403

__all__ = [
    'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
    'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
    'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]

 3、注册DATASETS

#自己编写的用来训练的.py文件
from mmdet.datasets.builder import DATASETS
#key-value添加到self._module_dict[name] = module_class
@DATASETS.register_module()
class KittiTinyDataset(CustomDataset):

下面来综合分析一下这些注册器类:DATASETS 、PIPELINES 、MODELS、HOOKS

#mmdet/datasets/builder.py
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
#mmdet/models/builder.py
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)
#mmcv/runner/hooks/hook.py
HOOKS = Registry('hook')

注意上面的这些语法,只是实例化了注册器类,下面的这段语法则是插入对应的类名和类:

 @DATASETS.register_module()
 @PIPELINES.register_module()
 @MODELS.register_module()
 @HOOKS.register_module()

下面代码中这些注册器类中的module_class,则是通过mmcv/utils/registry.py中的build_from_cfg函数,利用配置文件信息,来实例化。

self._module_dict[name] = module_class
总结:用HOOKS说明一下主要流程

1、实例化注册器类:HOOKS = Registry('hook')

2、用装饰器:@HOOKS.register_module()修饰对应的类,并在__init__.py中import,这样才会在import时执行装饰器,插入类名和类self._module_dict[name] = module_class

3、执行相关函数,根据配置文件实例化对应的module_class。比如:runner.register_training_hooks

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值