0.背景
1)什么是MMdetection?MMdetection是商汤科技推出的基于pytorch开源目标检测框架,包含了一些经典的目标检测方法。用户可以根据自己的需求,对模型参数进行修改,也可以使用其中的模块(如损失函数、Backbone等)搭建自己的网络。
用户最直接接触的就是config文件,在config文件中,用户主要设置model(模型结构)、datasete(数据集信息)、schedule(训练策略)、default_runtime。同时,各个config文件可以互相直接继承,继承以后再稍加修改就可以得到新的模型。具体细节见文档的五个教程。
2)什么是RetinaNet? RetinaNet是一个单阶段、Anchor-based目标检测器,其结构并没有特别的创新,它的贡献是创造了一个新的损失函数Focal Loss。论文说明了两阶段检测器(如Faster-RCNN)之所以精度比以往的单阶段(YOLO系列)检测器精度高,其中一个原因是正负样本不平衡。两阶段检测器由于会使用RPN对图像进行初步筛选,所以背景之类的负样本会少很多。单阶段检测器由于在训练的时候,负样本过多,导致预测的时候也会倾向于负样本。为此,文章提出的Focal Loss通过稍加修改Cross Entropy损失函数,降低了负样本过多带来的影响,使得单精度检测器的精度可以比肩多精度检测器。
3)这个系列适合谁看? 如果只是想使用RetinaNet训练自己的模型,可以直接看MMdetection的文档,使用起来非常简单,大概一两个小时就能搞懂怎么用和训练自己的数据集。同时本系列不会涉及论文解读,因为网上过多,而且retina创新的东西不多。本系列主要解析MMDetection中关于RetinaNet的源码,通过阅读源码提升实践能力。看这个系列需要首先了解RetinaNet的网络结构,其次了解Pytorch一些基本操作,最后了解一下MMdetection的基本的文件结构。
1. 原型机和总览
在config
文件夹中,包含了若干种已经配置好的模型和数据集配置文件。其中__base__
中有几个关键的模型配置见,其他文件夹中的配置文件往往是从__base__
中衍生出来的。文档中称这些被继承的配置文件叫做Primitive,我翻译成原型机。__base__/model
中有一个配置文件叫做retinanet_r50_fpn.py
按照命名规则可以看出,它使用的是ResNet50做Backbone体征提取,用FPN做Neck特征处理。
retinanet_r50_fpn.py
文件,首先着重看一下model的结构:backbone+neck+bbox_head+loss设置+train+test。所有参数,会在后面慢慢解析,这里只需要知道一下结构。
# model settings
model = dict(
type='RetinaNet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256