文章来源
paper: https://arxiv.org/pdf/2005.12872.pdf,
code: https://github.com/facebookresearch/detr
Motivation
作者提出了一种新的方法,将目标检测视为一个直接预测的问题。新框架称为DETR,主要组成部分是一个基于集合(set-based)的全局损失用来强制通过双部匹配得到唯一预测,以及一个Transformer编码器-解码器架构。给定一个固定的小的学习目标查询集,DETR找出目标和全局图像上下文之间的关系,从而直接并行输出最终的预测集结果。
方法
主要由2部分组成:a set prediction loss 以及新的检测器架构(基于transformer)
下面主要分析DETR模型架构:
DETR分为3部分:CNN (backbone)用于提取目标特征,encoder-decoder transformer 以及feed dorward netword (FFN).
(1)Backbone:采用了resnet50 ===>
f
x
∈
C
∗
H
∗
W
f_x \in C*H*W
fx∈C∗H∗W (C=2048)
(2)Transformer-encoder:
f
x
⇒
1
∗
1
c
o
n
v
⇒
f
x
∗
∈
d
∗
H
W
f_x ⇒ 1*1 conv ⇒ f_x^* \in d*HW
fx⇒1∗1conv⇒fx∗∈d∗HW
每个encoder都由multi-head self-attention module and a feed forward netword构成。
(3)Transformer-decoder:
参考了标准的transformer-decoder,每一个输出最终都分别和下一层的FFN连接。
(4)FFN: 采用了3层感知机实现。每个decoder的输出都连接到共享的FFN。由于预测的框数目为N,大于图像中目标的个数,于是一个额外的类别
∅
\varnothing
∅ 视为该框预测的是背景,而不是目标。
结论
作者首次将transformer引入到图像检测领域。在transformer处没有做创新,而是直接搬运了transformer的内容做encoder和decoder工作。
预测代码如下:
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads, num_encoder_layers, num_decoder_layers):
super().__init__()
# We take only convolutional layers from ResNet-50 model
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(hidden_dim, nheads,
num_encoder_layers, num_decoder_layers)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1))
return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)