视觉Transformer (二) End-to-End Object Detection with Transformers

文章来源

paper: https://arxiv.org/pdf/2005.12872.pdf,
code: https://github.com/facebookresearch/detr

Motivation

作者提出了一种新的方法,将目标检测视为一个直接预测的问题。新框架称为DETR,主要组成部分是一个基于集合(set-based)的全局损失用来强制通过双部匹配得到唯一预测,以及一个Transformer编码器-解码器架构。给定一个固定的小的学习目标查询集,DETR找出目标和全局图像上下文之间的关系,从而直接并行输出最终的预测集结果。

方法

主要由2部分组成:a set prediction loss 以及新的检测器架构(基于transformer)
下面主要分析DETR模型架构:
在这里插入图片描述
DETR分为3部分:CNN (backbone)用于提取目标特征,encoder-decoder transformer 以及feed dorward netword (FFN).

(1)Backbone:采用了resnet50 ===> f x ∈ C ∗ H ∗ W f_x \in C*H*W fxCHW (C=2048)
(2)Transformer-encoder: f x ⇒ 1 ∗ 1 c o n v ⇒ f x ∗ ∈ d ∗ H W f_x ⇒ 1*1 conv ⇒ f_x^* \in d*HW fx11convfxdHW
每个encoder都由multi-head self-attention module and a feed forward netword构成。
(3)Transformer-decoder:
参考了标准的transformer-decoder,每一个输出最终都分别和下一层的FFN连接。
(4)FFN: 采用了3层感知机实现。每个decoder的输出都连接到共享的FFN。由于预测的框数目为N,大于图像中目标的个数,于是一个额外的类别 ∅ \varnothing 视为该框预测的是背景,而不是目标。

结论

作者首次将transformer引入到图像检测领域。在transformer处没有做创新,而是直接搬运了transformer的内容做encoder和decoder工作。

预测代码如下:

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
	def __init__(self, num_classes, hidden_dim, nheads, num_encoder_layers, num_decoder_layers):
		super().__init__()
		# We take only convolutional layers from ResNet-50 model
		self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
		self.conv = nn.Conv2d(2048, hidden_dim, 1)
		self.transformer = nn.Transformer(hidden_dim, nheads,
		num_encoder_layers, num_decoder_layers)
		self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
		self.linear_bbox = nn.Linear(hidden_dim, 4)
		self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
		self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
		self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        x = self.backbone(inputs)
        h = self.conv(x)
        H, W = h.shape[-2:]
        pos = torch.cat([
        	self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
        	self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),        
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
        self.query_pos.unsqueeze(1))
        return self.linear_class(h), self.linear_bbox(h).sigmoid()

detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值