商汤开源目标检测工具箱mmdetection代码详解(一)------ build和Registry和配置信息,分析mmedetection如何动态构建网络

mmdetection版本:2.0

一、注册表全局变量:

1.  DATASETS, BACKBONES, LOSSES, DETECTORS 等等的注册表全局变量是什么:

我们在看mmdetection源码时,首先肯定找主函数main()在哪。主函数在 tool/train.py 里。然后我们遇到第一个疑惑的就会是build_detector(),然后往下看,还会 有一堆类似的,例如 build_dataset(),build_backbone(),build_neck(),build_loss()等等。我们暂时称这些为“build系列”,这些“build系列”都是用来把对应的class注册到(放到)一个全局变量中(相当于一个注册表),如DATASETS(这 是一个囊括了所有mmdetection支持的数据集的全局变量,可以从下面看到,DATASETS里不仅仅有指定包括了什么数据集,而且还有这个数据集的类所在的路径):

#DATASETS
Registry(name=dataset, items={
'CustomDataset': <class'mmdet.datasets.custom.CustomDataset'>, 
'CocoDataset': <class'mmdet.datasets.coco.CocoDataset'>, 
'CityscapesDataset': <class'mmdet.datasets.cityscapes.CityscapesDataset'>, 
'ConcatDataset': <class'mmdet.datasets.dataset_wrappers.ConcatDataset'>, 
'RepeatDataset': <class'mmdet.datasets.dataset_wrappers.RepeatDataset'>, 
'XMLDataset': <class'mmdet.datasets.xml_style.XMLDataset'>, 
'VOCDataset': <class'mmdet.datasets.voc.VOCDataset'>, 
'WIDERFaceDataset': <class'mmdet.datasets.wider_face.WIDERFaceDataset'>}) 

我们再看一个放backbones的注册表BACKBONES:

我们可以看到mmdetection给我们提供了很多用于提取特征的backbone。

#BACKBONES
Registry(name=backbone, items={
'ResNet': <class 'mmdet.models.backbones.resnet.ResNet'>, 
'ResNetV1d': <class 'mmdet.models.backbones.resnet.ResNetV1d'>, 
'HRNet': <class 'mmdet.models.backbones.hrnet.HRNet'>, 
'ResNeXt': <class 'mmdet.models.backbones.resnext.ResNeXt'>, 
'SSDVGG': <class 'mmdet.models.backbones.ssd_vgg.SSDVGG'>}) 

2. 注册表全局变量的作用:

那这些全局变量有什么用呢?仅仅是给用户说明mmdetection支持什么数据集,什么loss,什么backbone吗,肯定不是。当我们看到每个全局变量里面装的都是class,就应该知道,这些全局变量是用来把某一个class实例化的。就如BACKBONES,具体要实例化ResNet还是ResNext,就要看使用者怎么传参。

3. 构建注册表全局变量:

绝大部分的注册表全局变量都是在文件 /mmdet/models/builder.py 中定义的。

#/mmdet/models/builder.py  部分代码
from mmcv.utils import Registry,build_from_cfg
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

我们可以看到,全部的注册表全局变量实质都是 Registry类的实例化对象,只是传入的参数不同而已。

然后每个注册表全局变量(如BACKBONES)都需要往里面注册各种不同的类(如ResNet,ResNext)。那怎么注册呢?注册的方式就是利用python 的修饰器功能,(修饰器的详情可以看:https://blog.csdn.net/u014453898/article/details/68937325),我简单地介绍一下修饰器:

如下面代码所示,Registry定义一个注册表类,DATASETS为Registry的一个实例化对象,用于记录支持的数据集类。

通过注解@DATASETS.register_module()来调用修饰方法 register_module()来把COCO类和VOC类放进DATASETS中。

其中DATASETS也是一个全局变量,程序 一开始执行,DATASETS就会把COCO类和VOC类注册到自己里。最后我们通过打印DATASETS来看到DATASETS里面注册了什么东西。

#例子
class Registry(object):
    def __init__(self,name):
        self.module_dic = {}
        self.name = name
    def __repr__(self):
        return self.name+'_'+str(self.module_dic)
    def register_module(self):
        def _register(cls): #这里的cls就是调用@DATASETS.register_module()的类
            self.module_dic[cls.__name__]=cls
        return _register

DATASETS = Registry('datasets')

@DATASETS.register_module()
class COCO(object):
    def __init__(self):
        pass
@DATASETS.register_module()
class VOC(object):
    def __init__(self):
        pass
print(DATASETS)
#datasets_{'COCO': <class '__main__.COCO'>, 'VOC': <class '__main__.VOC'>}

mmdetection中的DATASETS,BACKBONES等注册表全局变量的形成原理和上面是一模一样的。

 

二、分析实例化注册表全局变量的接口:

从第一节我们可以知道,DATASETS,LOSSES,DETECTOS等等这些注册表全局变量都是把对应的类注册进自己里。就好像DATASETS里面就有{COCO类,VOC类等等},那这些类总得实例化才能用,所以以下就描述实例化这些类的过程。

在第一节提到的build_dataset(),build_backbone(),build_neck(),build_loss()等等就是用于实例化注册表全局变量的接口。

绝大部分的build_xxx()接口都是在文件 /mmdet/models/builder.py 中定义的。

#/mmdet/models/builder.py 
from mmcv.utils import Registry,build_from_cfg
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_backbone(cfg):
    return build(cfg, BACKBONES)

def build_neck(cfg):
    return build(cfg, NECKS)

def build_roi_extractor(cfg):
    return build(cfg, ROI_EXTRACTORS)

def build_shared_head(cfg):
    return build(cfg, SHARED_HEADS)

def build_head(cfg):
    return build(cfg, HEADS)

def build_loss(cfg):
    return build(cfg, LOSSES)

def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

从上面的代码可以看出:

1. build_xxx()基本上输入的都是一个cfg,即是配置信息

2. build_xxx()都会调用build(),build()的参数除了cfg配置信息还有的就是对应的注册表全局变量。build()的作用判断传入的配置文件是一个还是多个,传入一个配置文件,就生成一个实例,若传入多个配置文件,就生成多个实例。

我们可以从上图看到,全部build_xxx()最终都是调用build_from_cfg()进行实例化的,build_from_cfg()函数有3个参数(registry、cfg、default_args):

build_from_cfg(cfg, registry, default_args)

registry:某个注册表全局变量(可能是DATASETS的,也可能是LOSSES,BACKBONES,DETECTOS等等)

cfg:从配置文件读取而来的配置信息,不同的 registry ,传入的cfg也是不同的。例如LOSSES的registry,他的cfg的内容就是:

{'type': 'L1Loss', 'loss_weight': 1.0} 

表示实例化LOSSES注册表全局变量里的 ‘L1Loss’损失函数,并且这个损失函数的权重是1.0。我们看看LOSSES里是否有L1Loss这个类:(请看最后一行)

Registry(name=loss, items={
'BalancedL1Loss': <class 'mmdet.models.losses.balanced_l1_loss.BalancedL1Loss'>, 
'CrossEntropyLoss':<class'mmdet.models.losses.cross_entropy_loss.CrossEntropyLoss'>, 
'FocalLoss': <class 'mmdet.models.losses.focal_loss.FocalLoss'>, 
'GHMC': <class 'mmdet.models.losses.ghm_loss.GHMC'>, 
'GHMR': <class 'mmdet.models.losses.ghm_loss.GHMR'>, 
'IoULoss': <class 'mmdet.models.losses.iou_loss.IoULoss'>, 
'BoundedIoULoss': <class 'mmdet.models.losses.iou_loss.BoundedIoULoss'>, 
'GIoULoss': <class 'mmdet.models.losses.iou_loss.GIoULoss'>, 
'MSELoss': <class 'mmdet.models.losses.mse_loss.MSELoss'>, 
'SmoothL1Loss': <class 'mmdet.models.losses.smooth_l1_loss.SmoothL1Loss'>, 
'L1Loss': <class 'mmdet.models.losses.smooth_l1_loss.L1Loss'>}) 

我们可以看到LOSSES里面确实是有一个类的名字叫做L1Loss的(最后一行)。所以cfg包含了要如何实例化的信息。这个L1loss的配置信息算很短的了,我们看一个稍微长一点的,就是BACKBONES注册表全局变量的cfg配置信息:(从下面的配置 信息可以看出来,要实例化的就是ResNet了)

{'type': 'ResNet', 'depth': 50, 'num_stages': 4, 
'out_indices': (0, 1, 2, 3), 
'frozen_stages': 1, 
'norm_cfg': {'type': 'BN', 'requires_grad': True}, 
'norm_eval': True, 'style': 'pytorch'}

default_args:然后说说这个参数,这个参数默认为None,在绝大部分build_xxx()里都不许用传入这个函数,但是有一个是需要输入的,那就是build_detectors(),具体这个参数的作用是包含了训练配置 train_cfg和测试配置test_cfg

(cfg、train_cfg、test_cfg三个配置信息只是有联系的,下一节详细讲一下。)

那具体 build_from_cfg()做了什么呢?

build_from_cfg()具体就是通过传入的配置文件 cfg 来从 registrty 里面挑出对应的类,然后根据cfg的其他配置信息来对这个类进行实例化,当然如果 default_args 不为None的话,就会把 default_args 的信息加入到cfg中,来一起对 registry 的类进行实例化。build_from_cfg()源码:(PS:build_from_cfg()位于mmcv包中的utils.py中,由于mmcv是通过pip 安装的包,不属于mmdetection的项目代码,所以直接修改mmcv中的代码的话,当运行mmdetection是不会有任何改变的

#为了方便看代码,我删掉了一些判断“某个变量是否属于某类型”的语句
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    args = cfg.copy()   #用新的变量args复制cfg配置文件的内容
    obj_type = args.pop('type') #提取要实例化的类的名字
    if is_str(obj_type):
        obj_cls = registry.get(obj_type) #由类名从注册表全局变量中获取这个类
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    ‘’‘#如果default_args不为空,则把defaults_args的内容也加到args中’‘’
    if default_args is not None: 
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args) #主要是这一句,用配置信息args对类进行实例化

三、配置信息 cfg,train_cfg,test_cfg之间的关系

mmdetection是 一个集成了很多目标检测网络的工具箱。所以在训练mmdetection的时候,就必须要制定训练的是哪一个目标检测的网络,那如何制定呢?就通过选择对应目标检测网络的配置文件(这些配置文件都位于目录:/configs/models里)

这些配置文件里,只有三个变量,且都是字典,它们分别是 model,train_cfg,test_cfg

model包含的是这个目标检测模型的一些组成成分,如backbones,roi,rpn,loss等等,并制定了他们的一些参数。主要用于目标检测模型的初始化。以MaskRCNN为例:可以从下面的配置信息看出,model这个字典里面包含了很多有type 的东西,有type都可以看作小组件,因为type的类肯定都会被注册进某个注册表全局变量中的,例如下面代码中的ResNet,就会被在BACKBONES中被找到,而xxx就是它具体的配置信息。我们初始化目标检测网络的时候,就把model作为cfg配置文件传给build_detectors(cfg),但当我们初始化小组件(例如下面的L1Loss)的时候呢?也把这么大的配置文件传进build_loss(cfg)吗?当然不需要,mmdetection的做法是从model里提取出L1Loss的那一部分,把这部分作为配置信息cfg传给 build_loss(cfg)就可以了。

# model settings
model = dict(
    type='MaskRCNN',
    backbone=dict(
        type='ResNet',
        xxx
        ),
    neck=dict(
        type='FPN',
        xxx
        ),
    rpn_head=dict(
        type='RPNHead',
           xxx
        loss_cls=dict(
            type='CrossEntropyLoss', xxx),
        loss_bbox=dict(type='L1Loss', xxx),
    roi_head=dict(
        xxx
        mask_head=dict(
            type='FCNMaskHead',
            xxx
            loss_mask=dict(
                type='CrossEntropyLoss', xxx))))

然后还有 train_cfg和test_cfg两个配置信息,这两个分别是在训练时和测试时设置的一些参数,阈值罢了。只在实例化目标检测模型的时候会被用到。配置信息与build_xxx()的关系如下:model的配置信息是最大的,model的配置信息包含了backbone,loss等组件,如果要实例化组件时,再把组件的配置信息提取出来。

mmdetection构建网络的方式非常解耦,就像是把一个人拆成头,手,脚,身,几个组件,然后对组件进行配置后,再组装成一个整体。

  • 41
    点赞
  • 95
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值