DETR:End-to-End Object Detection with Transformers
- Encoder:Transformer encoder
- Decoder:Transformer decoder
- GT与预测Queries匹配:匈牙利匹配算法
- 损失计算:分类 + L1边框回归 + GIoU边框回归
- 源码地址:https://github.com/facebookresearch/detr
3. 前向推理过程
Debug mian.py文件大约到193行的时候进入每个epoch的训练,其中一个epoch的训练在train_one_epoch中
在train_one_epochj函数中首先是将模型切换为训练模式,之后记录一些配置信息,随后在data_loader取出一个epoch的图片
log_every函数用于取出data_loader对象,这个操作对应于225行的yield obj (log_every函数在misc.py文件中)
在得到一个批次的data_loader后,得到sample(输入图像)和targets(GroundTruth的一些信息),之后下一步就是将samples送到DETR结构中得到序列的分类和边框坐标位置输出
3.1 特征提取 + 分类回归输出
在DETR类的forward函数中执行:
- samples经过backbone得到features和positional embedding
- features解耦出src(一张图片经过ResNet的输出)和mask
- input_proj(src)将2048维特征映射到256维特征
- hs是256维特征经过transformer encoder和decoder的输出,注意维度是[6, b, 100, 256]。其中的6是decoder的层数,b是批次,100是Queries的长度
- outputs_class是输出输出特征图预测的类别信息:[6, b, 100, 92] (92 = 90个类别 + person + 空)
- outputs_coord是输出输出特征图预测的边框信息:[6, b, 100, 4]
- out里面包含:pred_logits(维度为[1, b, 100, 92],-1只取decoder最后一层输出 ),pred_boxes([6, b, 100, 4] )以及aux_outputs的信息,其中aux_outputs是其他5层decoder的信息,每个都包含pred_logits和pred_boxes
彩蛋
本文对DETR代码的模型构建和初始化进行详细阐述,笔者会持续分享DETR解析系列,笔者也建立了一个关于目标检测的交流群:465411015,欢迎大家踊跃加入,一起学习鸭!