0. 前言
-
CenterNet是我很喜欢的一篇论文,直观、好懂。然而,官方的 CenterNet 源码质量真的一般,看过的人应该都有这种感觉。
-
好消息是,MMDetection 中复现了 CenterNet,可以参考这里
-
此外,我想要复现时空行为检测中的 MOC-Detector,这篇文章也是基于 CenterNet 的,所以要捋一捋 CenterNet 源码。
-
MMDetection 的源码工程化非常好,但结构过于复杂,新手非常困难,老手如果长期不用估计也要忘。
-
本文只关注总体结构,不关注一些具体细节(比如gaussian heatmap gt如何实现等)
1. 模型构建
-
MMDetection 中模型的构建主要包括:
- 模型总体结构,管理模型的各个组件以及定义模型训练与测试是的前向流程。
- CenterNet 中就是 CenterNet
- 管理的组件包括
backbone/neck/bbox_head
三个部分
- 模型细节,包括 backbone/neck/bbox_head 的前向细节
- 模型训练相关,包括损失函数与GT的构建。
- 模型总体结构,管理模型的各个组件以及定义模型训练与测试是的前向流程。
-
模型总体结构是
CenterNet
类,其继承结构是 -
除了总体结构外,就是CenterNet的几个基本组件
- 即
backbone,neck,head
,分别是ResNet/CTResNetNeck/CenterNetHead
ResNet
就是普通的残差网络,没啥好说的CTResNetNeck
就是在ResNet后添加了若干 deconv 层,在Deconv前加了DCNv2。
- 即
2. BaseDetector
-
源码在mmdet.models.detectors.base.py中,主要作用就是定义一些所有检测器都会用到的功能。
-
抽象类,其他所有检测模型都会继承该对象。换句话说,检测的主要功能这个函数就全部定义完了,子类要做的就是实现里面的方法。
-
功能主要可以划分为:
- 判断组件是否存在,即
with_xxx
函数,其中,xxx 可以的取值有neck/shared_head/bbox/mask
- 定义模型前向的基本流程,根据 train/test 分别定义,后面会详细介绍这部分。
- 定义模型训练、验证时的流程,如获取模型结果、计算损失函数,即
train_step/val_step
,后面详细介绍这部分。 - 展示模型结果,即在 img 上画 bbox 和 labels,
show_result
函数 - 模型导出为 ONNX 格式,即
onnx_export
。
- 判断组件是否存在,即
-
模型前向推理:
-
入口函数是
forward
- 为什么是入口?这个是
nn.Module
中定义的,__call__
函数会调用forward
函数。 - 该函数根据模式分别调用
forward_train
和forward_test
函数。
- 为什么是入口?这个是
-
训练时函数前向流程入口函数
forward_train
- 这个函数一般会在子类中重写。
- 该函数的结果一般就是各种损失函数
-
验证/测试时前向流程入口
forward_test
-
有TTA则调用
aug_test
-
没有TTA则调用
simple_test
-
上面两个函数都是抽象函数,子类继承实现
-
该函数的结果一般是模型预测结果(经过后处理),而不包括损失函数
-
-
定义了特征提取抽象函数
extract_feat
和extract_feats
,一般会在forward_train/forward_test
的具体实现中引用特征提取这两个函数。
-
-
模型训练、验证时的流程
- 主要就是
train_step
和val_step
- 这一部分与前面
模型前向推理
的区别在于,在调用了模型前向推理函数(即model(**data)
)后,还会对模型结果进行一些封装。换句话说,就是对forward_train
的结果进行一些封装。 - 所谓封装,一般也就是调用
_parse_losses
函数,就是解析各种loss,封装成 dict 并累加求和 - 每次训练、验证的时候就需要调用者两个函数,主要问题就在于,什么时候调用。
- openmmlab 中,训练和测试都会使用 Runner 实现,而在Runner中就对调用者两个函数,如源码所示
- 主要就是
3. SingleStageDetector
-
源码在
mmdet.models.detectors.single_stage
中 -
定义了所有单阶段目标检测器的基本功能与流程。
-
细节上看,就是重写了
simple_test/aug_test/extract_feat/forward_train/forward_dummy/onnx_export
几个函数。 -
特征提取流程:就是 backbone + neck,没有啥好说的。
-
单阶段目标检测训练时流程:
- 特征提取+head前向
- 计算损失函数都是在 head 中定义的
-
无TTA测试流程
- 特征提取+head前向
- head前向中就是获取bbox,没有其他损失函数相关
-
有TTA 测试流程
- 特征提取+head前向
- head前向中实现tta
-
从上面可以看到,MMDetection 中的 head 实现了很多功能,包括
- 普通前向,获取预测结果
- 训练时,GT 构建,与预测结构匹配,并计算损失函数
- 测试时,对预测结果后处理(如NMS),获取最终结果
- 处理 TTA 的细节
- 从源码上看,head 需要有
forward_train
获取损失函数,get_bboxes
实现后处理获取检测框,aug_test
实现TTA
-
其实
SingleStageDetector
类已经比较完善了,只要导入各种backbone、neck、head 就能实现单目标检测功能了。
4. CenterNetHead
- CenterNet 实现的关键,主要功能包括:
- 普通前向,获取模型预测结果。
- 前向+后处理,获取过滤后的模型结果。
- 前向+损失函数/获取GT等。
- 前向+TTA
CenterNetHead
继承了BaseDenseHead
和BBoxTestMixin
BaseDenseHead
的主要功能就是定义了一个 head 应该做哪些工作losses
:计算损失函数get_bboxes
:根据模型结果获取 bboxes,包括后处理以及模型结果解析forward_train
:调用前面的 loss 函数,管理计算损失函数的过程simple_test
:定义基本前向过程
BBoxTestMixin
是个 Mixin 函数,不太懂,感觉就是包含一堆可复用的作为对象让别人来集成?- 主要就是 TTA 相关以及rpn相关
CenterNetHead
主要就是实现了loss/forward/get_bboxes
两个函数- loss 的功能主要包括:
- 根据
gt_bboxes/gt_labels
获取与预测结果一一对应的 GT。比如 heatmap 对应的就是符合高斯分布的圆。 - 分别计算几个分支的损失函数。
- 根据
forward
的主要功能包括:- 根据几个 head,以特征提取结果作为输入,获取模型最终预测结果
get_bboxes
的主要功能就是将模型预测结果转换为 bboxes- 大概就是解析 heatmap、进行NMS 等。
- loss 的功能主要包括: