YOLOv5改进系列(十五) 更换YOLOX解耦头
改进YOLOv5 | 头部解耦 | 将YOLOX解耦头添加到YOLOv5 | 涨点杀器
文章目录
理论
论文地址:https://arxiv.org/abs/2107.08430
YOLOX提出了一个Decoupled Head结构以代替YOLO Head,进而在YOLOv3 baseline的基础上提升了1.1个百分点的mAP,那为什么解耦头结构就能够提升检测效果呢?
我主要在YOLOX原论文讲述Decoupled Head这一部分,找到了引用的两篇文献,并加以解读。
第一篇文献是Song等人在CVPR2020发表的“Revisiting the Sibling Head in Object Detector”。
文中提出了,在目标检测的定位和分类任务中,存在spatial misalignment的问题,我的理解是两个任务所聚焦和感兴趣的地方不同,分类更加关注所提取的特征与已有类别哪一类最为相近,而定位更加关注与GT Box的位置坐标从而进行边界框参数修正。因此如果采取用同一个特征图进行分类和定位,效果会不好,即所谓的misalignment的问题。
下图是原论文的一张图,旨在展示分类和定位所关注的内容是不一致的!
第二篇文献是Wu等人(也是旷视的团队)在CVPR2020发表的“Rethinking Classification and Localization for Object Detection”
文中重新对检测任务中的分类和定位两个子任务进行解读,结果发现:fc-head更适合分类任务,conv-head更适合定位任务,如下面图表结果所示。
第一行是对于分类任务而言,红色是fc-head,蓝色是conv-head,可以看到,在分类的分数上,fc-head显然更具优势,特别对于small objects;
第二行是对于定位任务而言,可以看到,在边界框回归IOU值上,conv-head更具优势。
基于上述的实验结果,该文章设计了一个Double-Head的结构(应该YOLOX的解耦头结构的灵感就是从这里来的),来提升检测的效果。
从实验结果中也可以看到,使用这种Double-Head的结构,可以将mAP提升2-3个百分点,效果还是很不错的!
1. 解耦头原理
在目标检测中,分类任务和回归任务之间的冲突是一个众所周知的问题。因此,用于分类和定位的解耦头被广泛应用于大多数一级和二级探测器。但是,由于YOLO系列的主干和特征金字塔(如FPN, PAN)不断演化,它们的检测头仍然是耦合的,如图2所示。
我们的两个分析实验表明,耦合的检测头可能会损害性能。
- 将YOLO的头部替换为解耦的头部,可以大大提高收敛速度,如图3所示。
- 从表1可以看出,头耦合时端到端属性降低了4.2% AP,而头解耦时,端到端属性降低了0.8% AP。因此,我们将YOLO检测头替换为如图2所示的简化解耦头。具体来说,它包含一个1 × 1的conv层来降低信道维数,然后是两个平行的分支,分别有两个3 × 3的conv层。我们在表2中报告了V100上批处理=1时的推断时间,简化解耦头带来了额外的1.1 ms (11.6 ms vs . 10.5 ms)。
图2:YOLOv3头与解耦头之间的区别示意图。对于FPN的每一层特征,我们首先采用1×1的conv层将特征通道减少到256,然后添加两个并行分支,每个分支有2个3 × 3conv层,分别用于分类和回归任务。在回归分支上添加IoU分支。
2. 解耦头对收敛速度的影响
图3:解耦头的收敛速度比YOLOv3头快得多,最终取得了较好的收敛效果。
3. 解耦头对精度的影响
表1:以AP(%)表示的端到端YOLO解耦头对COCO的影响。
代码改进方式
第一步
yolo.py
中添加如下代码:
class DecoupledHead(nn.Module):
def __init__(self, ch=256, nc=80, width=1.0, anchors=()):
super().__init__()
self.nc = nc # 类别数量,即物体的类别数
self.nl = len(anchors) # 检测层的数量(通常是3)
self.na = len(anchors[0]) // 2 # 每个检测层的锚框数量
self.merge = Conv(ch, 256 * width, 1, 1) # 进行通道数的变换
self.cls_convs1 = Conv(256 * width, 256 * width, 3, 1, 1) # 分类任务的卷积层1
self.cls_convs2 = Conv(256 * width, 256 * width, 3, 1, 1) # 分类任务的卷积层2
self.reg_convs1 = Conv(256 * width, 256 * width, 3, 1, 1) # 回归任务的卷积层1
self.reg_convs2 = Conv(256 * width, 256 * width, 3, 1, 1) # 回归任务的卷积层2
self.cls_preds = nn.Conv2d(256 * width, self.nc * self.na, 1) # 分类预测的卷积层
self.reg_preds = nn.Conv2d(256 * width, 4 * self.na, 1) # 回归预测的卷积层
self.obj_preds = nn.Conv2d(256 * width, 1 * self.na, 1) # 目标置信度预测的卷积层
def forward(self, x):
x = self.merge(x) # 进行通道数的变换
# 分类任务
x1 = self.cls_convs1(x) # 第一个分类任务卷积层
x1 = self.cls_convs2(x1) # 第二个分类任务卷积层
x1 = self.cls_preds(x1) # 分类预测
# 回归任务
x2 = self.reg_convs1(x) # 第一个回归任务卷积层(共享)
x2 = self.reg_convs2(x2) # 第二个回归任务卷积层(共享)
x21 = self.reg_preds(x2) # 回归预测
# 目标置信度任务
x22 = self.obj_preds(x2) # 目标置信度预测
out = torch.cat([x21, x22, x1], 1) # 将三个任务的预测结果拼接在一起
return out
第二步
将self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
注释掉,换成下面这句:
self.m = nn.ModuleList(DecoupledHead(x, nc, 1, anchors) for x in ch)
第三步
将Model
中的self._initialize_biases()
注释掉
第四步
Decoupled=False,
self.decoupled = Decoupled
第五步
修改yaml
文件
(只修改了最下面那一行,多加了一个True
)
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.1 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Detect, [nc, anchors,True]], # Detect(P3, P4, P5)
]