OpenMMLab实训营(五)MMDetection

  1. 框架概述
      MMDetection是商汤和港中文大学针对目标检测任务推出的一个开源项目,它基于Pytorch实现了大量的目标检测算法,把数据集构建、模型搭建、训练策略等过程都封装成了一个个模块,通过模块调用的方式,我们能够以很少的代码量实现一个新算法,大大提高了代码复用率。

整个MMLab家族除了MMDetection,还包含针对目标跟踪任务的MMTracking,针对3D目标检测任务的MMDetection3D等开源项目,他们都是以Pytorch和MMCV以基础。Pytorch不需要过多介绍,MMCV是一个面向计算机视觉的基础库,最主要作用是提供了基于Pytorch的通用训练框架,比如我们常提到的Registry、Runner、Hook等功能都是在MMCV中支持的。另外,MMCV还提供了通用IO接口、多种CNN网络结构、高质量实现的常见CUDA算子,这里就不进一步展开了。

2.2 MMDetection
  使用Pytorch构建一个新算法时,通常包含如下几步:

注册数据集:CustomDataset是MMDetection在原始的Dataset基础上的再次封装,其__getitem__()方法会根据训练和测试模式分别重定向到prepare_train_img()和prepare_test_img()函数。用户以继承CustomDataset类的方式构建自己的数据集时,需要重写load_annotations()和get_ann_info()函数,定义数据和标签的加载及遍历方式。完成数据集类的定义后,还需要使用DATASETS.register_module()进行模块注册。
注册模型:模型构建的方式和Pytorch类似,都是新建一个Module的子类然后重写forward()函数。唯一的区别在于MMDetection中需要继承BaseModule而不是Module,BaseModule是Module的子类,MMLab中的任何模型都必须继承此类。另外,MMDetection将一个完整的模型拆分为backbone、neck和head三部分进行管理,所以用户需要按照这种方式,将算法模型拆解成3个类,分别使用BACKBONES.register_module()、NECKS.register_module()和HEADS.register_module()完成模块注册。
构建配置文件:配置文件用于配置算法各个组件的运行参数,大体上可以包含四个部分:datasets、models、schedules和runtime。完成相应模块的定义和注册后,在配置文件中配置好相应的运行参数,然后MMDetection就会通过Registry类读取并解析配置文件,完成模块的实例化。另外,配置文件可以通过_base_字段实现继承功能,以提高代码复用率。
训练和验证:在完成各模块的代码实现、模块的注册、配置文件的编写后,就可以使用./tools/train.py和./tools/test.py对模型进行训练和验证,不需要用户编写额外的代码。

配置文件

test_pipeline = [
    dict(type=‘LoadMultiViewImageFromFiles’, to_float32=True),
    dict(type=‘NormalizeMultiviewImage’, **img_norm_cfg),
    dict(type=‘PadMultiViewImage’, size_divisor=32)
]
evaluation = dict(interval=2, pipeline=test_pipeline)
 配置文件也支持继承操作,通过_base_变量来实现。_base_是一个list类型变量,里面存储的是要继承的配置文件的路径。在解析配置文件的时候,文件解析器以递归的方式(其他配置文件也可能包含_base_变量)解析所有配置文件。任何配置文件往上追溯都会继承以下四个文件,分别对应数据集(datasets)、模型(models)、训练策略(schedules)和运行时的默认配置(default_runtime):

base = [
    ‘mmdetection/configs/base/models/fast_rcnn_r50_fpn.py’,        # models
    ‘mmdetection/configs/base/datasets/coco_detection.py’,        # datasets
    ‘mmdetection/configs/base/schedules/schedule_1x.py’,            # schedules
    ‘mmdetection/configs/base/default_runtime.py’,                # defualt_runtime
]

1. 模型配置(models) =========================================

model = dict(
    type=‘FastRCNN’,            # 模型名称是FastRCNN
    backbone=dict(                # BackBone是ResNet
        type=‘ResNet’,
        …,
    ),
    neck=dict(                    # Neck是FPN
        type=‘FPN’,
        …,
    ),
    roi_head=dict(                # Head是StandardRoIHead
        type=‘StandardRoIHead’,
        …,
        loss_cls=dict(…),        # 分类损失函数
        loss_bbox=dict(…),    # 回归损失函数
    ),
    train_cfg=dict(                # 训练参数配置
        assigner=dict(…),        # BBox Assigner
        sampler=dict(…),        # BBox Sampler
        …
    ),
    test_cfg =dict(                # 测试参数配置
        nms=dict(…),            # NMS后处理
        …,
    )
)

2. 数据集配置(datasets) =========================================

dataset_type = ‘…’            # 数据集名称
data_root = ‘…’                # 数据集根目录
img_norm_cfg = dict(…)        # 图像归一化参数
train_pipeline = [                # 训练数据处理Pipeline
    …,
]
test_pipeline = […]            # 测试数据处理Pipeline
data = dict(
    samples_per_gpu=2,            # batch_size
    workers_per_gpu=2,            # GPU数量
    train=dict(                    # 训练集配置
        type=dataset_type,
        ann_file=data_root + ‘annotations/instances_train2017.json’,    # 标注问加你
        img_prefix=data_root + ‘train2017/’,    # 图像前缀
        pipline=trian_pipline,                    # 数据预处理pipeline
    ),
    val=dict(                    # 验证集配置
        …,
        pipline=test_pipline,
    ),
    test=dict(                    # 测试集配置
        …,
        pipline=test_pipline,
    )
)

3. 训练策略配置(schedules) =========================================

evaluation = dict(interval=1, metric=‘bbox’)
optimizer = dict(type=‘SGD’, lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy=‘step’,
    warmup=‘linear’,
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type=‘EpochBasedRunner’, max_epochs=12)

4. 运行配置(runtime) =========================================

checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type=‘TextLoggerHook’)])
custom_hooks = [dict(type=‘NumClassCheckHook’)]
dist_params = dict(backend=‘nccl’)
log_level = ‘INFO’
load_from = None
resume_from = None
workflow = [(‘train’, 1)]

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值