mmdetection源码阅读笔记(0)--创建模型

之前做天池比赛用mmdetection取得了还不错的成绩,就想仔细读读mmdetection的源码,了解下具体实现。

这个系列,准备按照目标检测和实例分割的pipeline来写。


训练脚本

官方提供了分布式训练,并且推荐使用分布式训练,即使在单机器上dist_train.sh

#!/usr/bin/env bash

PYTHON=${PYTHON:-"python3"}

$PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}

该脚本主要使用了torch.distributed.launch辅助启动工具,这个工具可以辅助在每个节点上启动多个进程process,支持Python2 和 Python3.
更多关于分布式训练的细节可以参考pytorch 分布式训练 distributed parallel 笔记


创建模型

train.pymain()函数,先做了一些config文件,work_dir以及log的操作,之后调用了build_detector()来创建模型。

build_detector()

build_detector()定义在mmdet/models/builder.py中。
下面是主要用到的几个函数。
mmdet/models/builder.py

from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS

def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
    

build_detector()中有一个DETECTORS这是一个注册器,里面保存了所有支持的detector。具体的实现方式和Python装饰器有点像。
下面以cascade_rcnn为例,看下是怎么进行注册过来的。

  1. 首先在mmdet/models/__init__.py里面from .detectors import *
  2. mmdet/models/detectors/__init__.py里面from .cascade_rcnn import CascadeRCNN
  3. mmdet/models/detectors/cascade_rcnn.py
from ..registry import DETECTORS
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):
    other codes

@DETECTORS.register_module这一行代码,将CascadeRCNN注册到了DETECTORS中。
这里简单的说下@的用法,Python当解释器读到@的这样的修饰符之后,会先解析@后的内容,直接就把@下一行的函数或者类作为@后边的函数的参数,然后将返回值赋值给下一行修饰的函数对象。
例如:

def a():
    print("func a")
def b():
    print("func b")
@a
@b
def c():
    print("func c")

python会按照自下而上的顺序把各自的函数结果作为下一个函数(上面的函数)的输入,也就是a(b(c()))
回到我们的DETECTORS,也就是上面的操作将CascadeRCNN传给了DETECTORS.register_module
mmdet/models/registry.py

class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not issubclass(module_class, nn.Module):
            raise TypeError(
                'module must be a child of nn.Module, but got {}'.format(
                    module_class))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
HEADS = Registry('head')
DETECTORS = Registry('detector')

注册的模型被保存到了,self._module_dict中。
再回到builder.py
mmdet/models/builder.py

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return _build_module(cfg, registry, default_args)
        
def _build_module(cfg, registry, default_args):
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        if obj_type not in registry.module_dict:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
        obj_type = registry.module_dict[obj_type]
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)

build()中主要通过_build_module()registry.module_dict中实例化注册过的模型。


最后

这篇主要讲了mmdetection中的创建模型,下一篇准备以Cascade Rcnn为例看下网络的具体搭建。

  • 12
    点赞
  • 60
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值