1. 前言
开源目标检测框架mmdetection
在2D/3D
的目标检测上受到越来越多使用者的青睐。作为一个半商业半学术的开源框架,它的很多内部细节是值得学习和借鉴的。同时,没有一个框架是绝对完美的,但是站在巨人的肩膀上无疑会看的更远。小白我认识mmdetection
框架是源于对检测算法SA-SSD
的研究。在使用和调试SA-SSD
和领略mmdetection
框架的便捷高效之后,我打算认真分析mmdetection
的架构。这是我将要写作的系列博客的核心原因。考虑到我主要做3D
目标检测之一块,所以代码的阅读也是以3D
目标检测为主。mmdetection
的官方文档可点击这里进入。
2. 背景
mmdetection
框架是基于pytorch
框架建立的。使用mmdetection
框架做训练,本质上还是调用了pytorch
的深度学习训练的API
。但是呢,随着目标检测近几年的发展,它的算法组成模块也逐渐固定下来,大致分为Backbone
,Neck
,以及Head
这么几个部分,如图1所示。其实,这对基于Anchor
还是Anchor-free
,或者对单阶段还是多阶段检测,都是适用的。如果目标检测是基于Anchor
的,Anchor
的生成以及选优等过程也是一套固定的流程。所有的目标检测算法都会使用非极大值抑制(NMS
)技术得到不重叠的2D/3D
目标框。NMS
的算法框架也是固定的。
图1:单阶段和多阶段目标检测的示意图
目标检测的误差函数也逐渐固定下来。对于检测目标分类使用One-hot
编码的交叉熵损失函数,对于正负样本不均衡情况则改用Focal Loss
误差函数;对于目标检测位置则使用Smooth-L1
损失函数。当然,有些3D
目标检测的位置误差函数使用一些不太一样的形式(“粗糙分类+精细回归”的模式),比如PointRCNN
和F-PointNet
。但是它们最后都可以拆解为交叉熵损失函数(对应“粗糙分类”)和Smooth-L1
损失函数(对应“精细回归”)。
除此之外,目标检测的主要数据集是COCO
和KITTI
,在短期之内不会有变化。对原始数据集做处理的流程也是固定的。以及目标检测的衡量指标,计算2d/3d iou
指标和计算AP
的流程也是相对固定的。
其实我们还忽略了商业工业上的因素。对于自动驾驶技术甚至是安防工业等项目,对目标检测技术是高度需求的。以自动驾驶为例说明(对实时性有极高地要求),对车类和行人目标的检测,有助于提供车辆避障和路径规划的稳定性。同样对于安防领域(对实时性要求较低),需要监视一块区域是否有无关物体闯入闯出等等。传统的目标检测方法不能够适用于复杂多变的场景,这为基于深度学习的目标检测提供了发展机会。
总而言之,可以发现,2d/3d
目标检测的各个模块,误差函数,评价指标,以及线下数据集都是高度成熟和固定的。因此,无论是对学术界还是工业界的研究者,在研究或开发一个目标检测算法的前期,实在是没有必要把这些重复性的高度成熟的轮子再造一遍,更何况这些模块的代码量也是巨大的(后期可能会做工程上的优化)。所以呢,我们需要一个工具框架,帮助我们实现那些固定化的模块,能够让我们轻松地调用,能够让网络训练地更快更准。这就是mmdetection
框架设计的初衷。同时mmdetection
框架有较好的可扩展性,方便使用者添加自定义的模块。
在这一节的最后,来说点题外话,就是2d/3d
目标检测的高度成熟,真的就代表目标检测算法已经在真实场景下有了出色的表现吗?答案显然是否定的。对于工业应用,最为核心的问题恐怕是“数据”这一块。深度学习算法都是数据驱动的(data dirven
)。为了提供目标检测的性能,就需要更多的训练数据,公司就得在数据标注上下更大的功夫,比如设计更为简便的标注工具,雇佣更多的标注人员。对于工业应用,最重要的问题是“实时鲁棒”,即目标检测算法长时间运行对整个自动驾驶系统不会有不良影响。解决这个问题,就需要聘请算法工程师对核心模块重构,编写轻量稳定的新架构,也需要测试开发工程师写前端界面测试环境等。对于工业应用,最为关键的问题是“网络架构”,需要高端算法工程师和算法研究员的参与。不过这个高端职位应该是留给手握顶会的求职者吧。这一段的讨论归结为图2。
图2:以目标检测为例的AI
部门团队的人员划分(项目经理也应该放在统筹的位置)
3. 架构
个人认为,mmdetection
框架大致分为如下的七个主要模块:
- 训练推断模块:它位于
mmdet/apis
,提供网络训练和推断的框架代码,能够根据网络配置文件cfg
的要求,初始化目标检测网络以及设定优化器模式,提供所需超参数。 - 数据输入模块:它位于
mmdet/datasets
,用于提供做训练的各个类型数据集(比如kitti
和coco
)的输入格式并支持数据增广。对点云的数据增广则位于mmdet/core/point_cloud
。 - 网络架构模块:它位于
mmdet/models
,用于提供目标检测网络常用的Backbone Network
,Neck Network
,以及Detection Head
等等。目标检测网络使用的非极大值抑制则存放在mmdet/core/post_processing
中。 - 指标计算模块:它位于
mmdet/core
,用于提供2d/3d iou
的计算,以及PR
曲线和AP
值的计算。 - 损失函数模块:它位于
mmdet/core/loss
,用于提供目标检测中常用的误差损失函数,比如分类使用的交叉熵损失函数和focal loss
,3d
框回归使用的平滑l1
范数等等。 Anchor
生成模块:它位于mmdet/core/anchor
,用于给基于anchor
的目标检测算法提供anchor
。- 自定义运算模块:它位于
mmdet/ops
,用于添加自定义的点云处理模块,比如PointNet++
一些运算。
4. 结束语
这篇博客首先讨论了mmdetection
产生的背景,然后概括了mmdetection
的七个模块的主要内容。