一阶段检测模型的构建(以FCOS为例)
引言
对于笔者来说,学习mmdetection最重要的就是学会如何DIY自己的模型,那么弄懂经典的模型是如何一步步搭建出来的就显得非常重要。本章将从零开始,解构一阶段检测模型(以FCOS为例)在mmdetection中的构建过程。本文在FCOS的具体代码部分有详细的讲解和注释,和笔者一样有DIY模型需求的同学可以仔细阅读哦~
一阶段检测模型的构成
一阶段检测模型(OneStageDetector)直接在提取到的特征图上进行目标位置和类别的密集预测。一阶段检测器的模型构成较为简单,关键组件为Backbone, Neck, Bbox_head 三部分:
一般而言,不同一阶段检测模型的差异主要集中在检测头(bbox_head)和损失函数(loss)两部分,这两部分也是本章要讨论的重点。
mmdet中的FCOS模型构建
方法介绍
论文地址:FCOS: Fully Convolutional One-Stage Object Detection
(上图来源于论文原文)
FCOS采用全卷积的网络结构,并且在检测头的分类分支上采取一条额外的中心预测支路来预测目标离中心的偏移程度,能够提升检测边框的质量。这是一篇很有意思的工作,关于模型的具体细节感兴趣的同学可自行阅读原文,本文主要讨论FCOS在mmdetection中的实现过程。
回到上文提到的一阶段检测模型结构“三要素”,FCOS在Backbone上没有特殊的要求,Neck使用的是检测任务中最常用的特征金字塔(FPN),其创新主要体现在bbox_head部分。
mmdet模型一般构建过程
在mmdetection/tools/train.py中,模型是这样构建的:
from mmdet.models import build_detector
"""可以看到,模型的构建取决于cfg.model,cfg.train_cfg,cfg.test_cfg 三个字典"""
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
"""模型参数初始化"""
model.init_weights()
接下来我们看到FCOS的配置文件,笔者对重要的部分做了注释:
"""model部分"""
model = dict(
"""type指示模型的类(class)"""
type='FCOS',
"""指定backbone,backbone一般是模型中最灵活的部分,可以方便地替换为其他网络"""
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron/resnet50_caffe')),
"""这里使用FPN做为neck,并指定了FPN的输入、输出通道数、是否使用relu等参数"""
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output', # use P5
num_outs=5,
relu_before_extra_convs=True),
"""这里指定了bbox_head,可以看到type为'FCOShead',是FCOS的核心组件"""
bbox_head=dict(
type='FCOSHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
"""loss归属于bbox_head部分,这里指定了检测头三个分支的损失函数"""
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
"""训练和测试的配置部分"""
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
配置文件的基本结构如下图所示,FCOShead为此模型的关键部分:
接下来看到检测模型的构建函数build_detector():
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)
DETECTORS = MODELS
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg is deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
DETECTORS是mmcv的一个注册表实例,DETECTORS.build(cfg)为按照cfg给定的模型类别(关键字’type’)来实例化一个模型。在本例中,模型的类别为 type=‘FCOS’,那么此函数将会构建并返回一个FCOS模型。
FCOS
我们在 mmdetection/mmdet/models/detectors/fcos.py找到FCOS类别的定义:
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class FCOS(SingleStageDetector):
"""Implementation of `FCOS <https://arxiv.org/abs/1904.01355>`_"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained, init_cfg)
出乎意料,FCOS的定义非常简单,就是继承了SingleStageDetector并做实例化而已。
SingleStageDetector
下面来看一下SingleStageDetector吧,这里同样对重点部分做了一些注释:
@DETECTORS.register_module()
class SingleStageDetector(BaseDetector):
"""一阶段检测器的基类
一阶段检测器在backbone+neck的输出上直接进行密集的边界框预测
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(SingleStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
"""创建backbone,build函数作用与build_detecor相同"""
self.backbone = build_backbone(backbone)
if neck is not None:
"""创建neck, 检测中一般使用FPN及其各种变体"""
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
"""创建bbox_head"""
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
"""特征提取"""
def extract_feat(self, img):
"""使用backbone和neck提取特征"""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_dummy(self, img):
"""前向传播算法 x ---> backbone+neck ---> feat ---> bbox_head ---> outs"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
return outs
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None):
"""
参数:
img (Tensor): 输入图片,形状为(N,C,H,W),一般来说应当是归一化后的图片.
img_metas (list[dict]): 包含'image_scale','flip','filename','ori_shape'等元信息的字典列表
gt_bboxes (list[Tensor]): 边界框真实标注,形状为(xmin,ymin,xmax,ymax).
gt_labels (list[Tensor]): 边界框的类别
gt_bboxes_ignore (None | list[Tensor]): 指定在计算损失时可以被忽略的边界框.
返回值:
dict[str, Tensor]: 包含多个损失函数的字典.
"""
super(SingleStageDetector, self).forward_train(img, img_metas)
"""提取特征,并用bbox_head的前向传播函数得到损失"""
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_bboxes_ignore)
return losses
以上是SingleStageDetector的训练部分,下面看到测试部分:
def simple_test(self, img, img_metas, rescale=False):
"""没有测试阶段数据增强的简单测试函数
参数:
rescale (bool, optional): 是否将检测结果放缩到原有的大小,默认为False.
返回值:
list[list[np.ndarray]]: 每张图片中每一个类别的检测结果,第一个list维度代表不同的图片,第二个list维度代表不同的类别.
"""
"""提取特征并使用bbox_head的simple_test函数进行测试"""
feat = self.extract_feat(img)
results_list = self.bbox_head.simple_test(
feat, img_metas, rescale=rescale)
"""使用bbox2result函数处理检测结果并返回"""
bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in results_list
]
return bbox_results
def aug_test(self, imgs, img_metas, rescale=False):
"""与simple_test函数基本一致,只是在测试时使用了数据增强
"""
assert hasattr(self.bbox_head, 'aug_test'), \
f'{
self.bbox_head.__class__.__name__}' \
' does not support test-time augmentation'
feats = self.extract_feats(imgs)
"""使用bbox_head的aug_test函数"""
results_list =