RT-DETR很像detr系列发展的成果总结, 对应于yolov4的发布阶段。v4总结之前各种调参工艺,使得yolo系列在当时取得SOTA结果; RT-DETR工作也融合当前detr各种研究成果,使得fps与AP获得与yolov8性能相匹配。
源代码使用了torchvision的代码模块,版本必须是0.15.2, 其他版本都会报错(连0.15.1都会出错)
论文链接
https://arxiv.org/pdf/2304.08069.pdf
源码链接
https://github.com/lyuwenyu/RT-DETR
当前项目链接
https://github.com/yblir/RT-DETR
一 RT-DETR整体网络结构
rtdetr模型结构可以与官方代码完美对应, 如下图所示.
代码主干逻辑非常清晰,每个模块的输入/输出shape已经在图中展示出来了. 接下来将分别解析backbone, encoder, decoder三个部分!
- src/zoo/rtdetr/rtdetr.py
二 backbone模块
backbone之一是对resnet的魔改,命名为presnet,主要修改两点:
第一是把开始阶段的7x7卷积改为三组3x3卷积,通过调整步长使得输出shape与原来的7x7保持一致,这样可以更好地提取特征。
第二是把resnetblock中1x1,步长为2的池化模块替换为步长为2的全局池化,然后1x1调整通道数,这样做的目的是原来的下采样方式会丢失信息,修改后更多地保留信息。
backbone不是我们的重点,这里不过多关注,因为可以替换成其他网络,比如换成yolov8精心调参过的backbone效果会怎样呢?
ps: 官方代码只放出了presnet, 性能最好的backbone–HGNetv2并没在源码中给出。
假设输入数据shape为bs,3,576,576,会分别从stage2,3,4获得s3,s4,s5特征。
- src/nn/backbone/presnet.py
三 Encoder模块
encode对应结构图中efficient hybrid encoder(高效的混合编码)
3.1 AIFI
使用backbone输出的s5进行transformer编码
- src/zoo/rtdetr/hybrid_encoder.py (class HybridEncoder)
一些变量说明:
self.input_proj是一个包含3个线性结构的ModuleList,作用是将s3,s4,s5的统一输出通道数,为了以后的尺度间融合
self.encoder是transformer的编码器结构,代码中套了好几层,模型结构如下所示。在detr和deformable detr中encode是6层,这里为加快运行速度,只使用了一层encoder。
关于AIFI编码器 memory = self.encoder[i](src_flatten, pos_embed=pos_embed) 的实现部分如下:
在当前模块的TransformerEncoderLayer类中
- 3.2 CCFM 实现过程
CCFM可以认为就是个PAFPN。关于Fusion模块,对应下面代码中的self.fpn_blocks和self.pan_blocks, 虽然是两个变量名不同,当装的都是CSPRepLayer层,关于Fusion,下面单独解析。
关于Fusion模块
四 decoder模块
- src/zoo/rtdetr/rtdetr_decoder.py
decode部分比encode复杂很多,整合多种detr相关技术.
以下代码中, 未框起来的代码是对输出结果的整合
4.1 解析 (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
下面self.input_proj是三组1x1卷积。
_get_encoder_input作用:将encoder输出的三个特征拍平拼在一起,作为decoder的输入,并且记录哪些token索引对应原来的特征层。
4.2 解析get_decoder_input() : IOU-aware+decode输入
解决类别置信度与位置置信度表现不一样的问题.
iou-aware做法是让低iou的预测box有低类别置信度. 高iou的box有高类别置信度.
4.3 def get_contrastive_denoising_training_group()
模块来自DN-DETR模型, 方法是,对真是框进行加噪(类别+坐标)后, 作为decoder的输入的一部分.这样decoder的预测输出(query)就会非常明确知道自己在预测哪个object, 因为就是由那个目标加噪得来的. 这样就能加速训练收敛过程. 在推理阶段是没有加噪的,在4.2节代码解析中有体现.
这个代码模块可以当做一个有效的黑盒模块使用, 不想深究了, 以后有兴趣再补充吧!
4.4 decoder模块
模块来自Deformable DETR.
部分变量说明:
memory_mask: 哪写memory中的query是pading出来的.
score_head: ModuleList, 对每个docoder层输出进行类别预测
bbox_head: ModuleList, 对每个docoder层输出进行坐标预测
解析layer层:
主要模块Multi-scale Deformable-Attention 来自Deformable DETR模型, 这里不展开了, 有兴趣了以后再补充!