RT-DETR模型代码图解分析

RT-DETR很像detr系列发展的成果总结, 对应于yolov4的发布阶段。v4总结之前各种调参工艺,使得yolo系列在当时取得SOTA结果; RT-DETR工作也融合当前detr各种研究成果,使得fps与AP获得与yolov8性能相匹配。
源代码使用了torchvision的代码模块,版本必须是0.15.2, 其他版本都会报错(连0.15.1都会出错在这里插入图片描述)

论文链接
https://arxiv.org/pdf/2304.08069.pdf

源码链接
https://github.com/lyuwenyu/RT-DETR

当前项目链接
https://github.com/yblir/RT-DETR

部分结构图参考
https://www.bilibili.com/video/BV1Nb4y1F7k9/?spm_id_from=333.788&vd_source=e6c58a8524018df4b96578742d7ddbde

一 RT-DETR整体网络结构

rtdetr模型结构可以与官方代码完美对应, 如下图所示.
代码主干逻辑非常清晰,每个模块的输入/输出shape已经在图中展示出来了. 接下来将分别解析backbone, encoder, decoder三个部分!

  • src/zoo/rtdetr/rtdetr.py
    在这里插入图片描述

二 backbone模块

backbone之一是对resnet的魔改,命名为presnet,主要修改两点:
第一是把开始阶段的7x7卷积改为三组3x3卷积,通过调整步长使得输出shape与原来的7x7保持一致,这样可以更好地提取特征。
第二是把resnetblock中1x1,步长为2的池化模块替换为步长为2的全局池化,然后1x1调整通道数,这样做的目的是原来的下采样方式会丢失信息,修改后更多地保留信息。
backbone不是我们的重点,这里不过多关注,因为可以替换成其他网络,比如换成yolov8精心调参过的backbone效果会怎样呢?

ps: 官方代码只放出了presnet, 性能最好的backbone–HGNetv2并没在源码中给出。
在这里插入图片描述

假设输入数据shape为bs,3,576,576,会分别从stage2,3,4获得s3,s4,s5特征。

  • src/nn/backbone/presnet.py
    在这里插入图片描述

三 Encoder模块

encode对应结构图中efficient hybrid encoder(高效的混合编码)

3.1 AIFI
使用backbone输出的s5进行transformer编码

  • src/zoo/rtdetr/hybrid_encoder.py (class HybridEncoder)
    一些变量说明:
    self.input_proj是一个包含3个线性结构的ModuleList,作用是将s3,s4,s5的统一输出通道数,为了以后的尺度间融合
    self.encoder是transformer的编码器结构,代码中套了好几层,模型结构如下所示。在detr和deformable detr中encode是6层,这里为加快运行速度,只使用了一层encoder。
    在这里插入图片描述

在这里插入图片描述

关于AIFI编码器 memory = self.encoder[i](src_flatten, pos_embed=pos_embed) 的实现部分如下:
在当前模块的TransformerEncoderLayer类中
在这里插入图片描述

  • 3.2 CCFM 实现过程
    CCFM可以认为就是个PAFPN。关于Fusion模块,对应下面代码中的self.fpn_blocks和self.pan_blocks, 虽然是两个变量名不同,当装的都是CSPRepLayer层,关于Fusion,下面单独解析。
    在这里插入图片描述
    关于Fusion模块
    在这里插入图片描述

四 decoder模块

  • src/zoo/rtdetr/rtdetr_decoder.py

decode部分比encode复杂很多,整合多种detr相关技术.

以下代码中, 未框起来的代码是对输出结果的整合
在这里插入图片描述

4.1 解析 (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
下面self.input_proj是三组1x1卷积。
在这里插入图片描述
_get_encoder_input作用:将encoder输出的三个特征拍平拼在一起,作为decoder的输入,并且记录哪些token索引对应原来的特征层。
在这里插入图片描述

4.2 解析get_decoder_input() : IOU-aware+decode输入
解决类别置信度与位置置信度表现不一样的问题.
iou-aware做法是让低iou的预测box有低类别置信度. 高iou的box有高类别置信度.
在这里插入图片描述

4.3 def get_contrastive_denoising_training_group()
模块来自DN-DETR模型, 方法是,对真是框进行加噪(类别+坐标)后, 作为decoder的输入的一部分.这样decoder的预测输出(query)就会非常明确知道自己在预测哪个object, 因为就是由那个目标加噪得来的. 这样就能加速训练收敛过程. 在推理阶段是没有加噪的,在4.2节代码解析中有体现.
这个代码模块可以当做一个有效的黑盒模块使用, 不想深究了, 以后有兴趣再补充吧!
在这里插入图片描述
4.4 decoder模块

模块来自Deformable DETR.

部分变量说明:
memory_mask: 哪写memory中的query是pading出来的.
score_head: ModuleList, 对每个docoder层输出进行类别预测
bbox_head: ModuleList, 对每个docoder层输出进行坐标预测

在这里插入图片描述

解析layer层:
主要模块Multi-scale Deformable-Attention 来自Deformable DETR模型, 这里不展开了, 有兴趣了以后再补充!
在这里插入图片描述

评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值