论文阅读笔记:End-to-End Object Detection with Transformers
前言
目标检测(Object Detection)如RCNN系列或YOLO系列都是使用滑动窗口的方式通过先验覆盖整幅图可能出现目标的部分来进行先验框偏移量的预测从而找到目标。
而加入Transformers后,可以得到pixel序列之间的依赖关系确定整幅图的实例关系,从而使得模型有选择地聚焦于输入的某些部分。
相比之下,Detection Transformers似乎更直接,他无需去设计检测的方法,从输入到输出实现更本质的目标检测。
Transformer介绍博客:论文阅读笔记:Attention Is All You Need
论文原文:End-to-End Object Detection with Transformers
代码下载地址: Github
Detection Transformer
整体流程
为了填补现有目标检测框架中存在的重复预测后处理步骤、anchor集合设计以及anchor与gt的分配策略等,作者提出了采用了基于transformers的编解码框架,该框架被广泛应用于序列预测。transformers中的自注意力机制使该框架非常适合于集合预测的特定约束比如移除重复预测等,自注意力机制显示的建模了序列中元素对之间的交互关系。
该文章提出的目标检测思路是通过CNN进行特征提取,再以特征图的通道为一向量,将
h
×
w
h×w
h×w大小的特征图转换为
h
×
w
h×w
h×w个特征向量,把特征向量集传入transformer模型来完成预测固定个数的结果,训练过程采用二分图匹配损失函数(bipartite matching loss)。
二分图匹配
二分图匹配就是将Boundingbox和Groundtruth进行匹配。模型会输出固定数量的预测box,对于预测值数量不够的情况会用 ∅ 补齐,标签的Groundtruth会和输出的值数量一致,不够的也会用 ∅ 补齐。
如下图所示:
直接使用集合损失(使用匈牙利算法来计算),即便预测输出同一目标多次,也只能有一个GT与之对应,类似NMS。最终模型通过训练就会学会在两边给出同样个数的无类别预测 ∅,因为如果给出一个不应该的到的输出都会受到惩罚。比如预测的两个物体和GT中的两个标签能够对应上,模型还给出第三个物体的预测,但GT中已经没有标签可以与之对应了就应该受到惩罚,因为剩下的预测应该是无类别的。这依赖于输出的预测数量必须固定且GT也应该补充到和其一致。
关于其匹配原理和代码可查看博文:算法讲解:二分图匹配【图论】。
目标检测集合预测损失
目标检测中使用直接集合预测最关键的两个点是:1)保证真实值与预测值之间唯一匹配的集合预测损失。2)一个可以预测(一次性)目标集合和对他们关系建模的架构。
二分图匹配损失是真实值与预测值之间两两匹配的Loss。使用匈牙利算法来计算。匹配损失同时考虑到类别与真实值与预测值之间的相似度,
y
=
(
c
i
,
b
i
)
y=(c_i,b_i)
y=(ci,bi)其中
c
c
c 是目标的类别,
b
b
b 是值域在[0,1]的四维向量,bbox的中心坐标与宽高。
bbox损失直接使用L1loss的话,对小目标就不公平,因此使用了L1 loss 与IOU loss的组合,让loss对目标的大小不敏感。
损失函数的具体形式如下:
与常见的检测模型很相似,就是负对数似然与box损失的线性组合。其中
σ
^
\hat{σ}
σ^是二分图匹配损失中求得的最优匹配。类似于faster-rcnn对负样本权重的设置,当
c
i
=
∅
c_i=∅
ci=∅ 时,权重为原来的十分之一。目标与 ∅ 的匹配损失不依赖于预测值,因此是一个常量。在匹配损失中,使用概率去代替对数概率。这样是为了平衡类别预测与box预测的损失,效果更好。
网络结构
网络结构如下:
BackBone:CNN,输入
x
i
m
g
∈
R
C
0
∗
H
0
∗
W
0
x_{img}∈R^{C_0*H_0*W_0}
ximg∈RC0∗H0∗W0,输出
f
∈
R
C
∗
H
∗
W
f∈R^{C*H*W}
f∈RC∗H∗W,
H
=
H
0
/
32
,
W
=
W
0
/
32
H=H_0/32,W=W_0/32
H=H0/32,W=W0/32。
Transformer相关内容(encoder.,decoder,FFNs)可查看前一篇博文:论文阅读笔记:Attention Is All You Need
附加的解码loss:使用附加的loss对模型的训练有帮助,每一个decoder层后面加上FFNs和匈牙利loss。所有FNNs共享权重。使用共享的 layer-norm 去归一化不同decoder层的输出。
Pytorch Inference Code
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
super().__init__()
# We take only convolutional layers from ResNet-50 model
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(hidden_dim, nheads,
num_encoder_layers, num_decoder_layers)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))
return self.linear_class(h), self.linear_bbox(h).sigmoid()
可以看出Positional Encoding(位置编码)是torch.rand(50, hidden_dim // 2)随机生成出来的。而不像上一篇论文中提到的使用COS/SIN函数来分割出来。