ResT:用于图像识别的高效Transformer

📝论文下载地址

  [论文地址]

🔨代码下载地址

  [GitHub-official-Pytorch]
  [GitHub-unofficial-Pytorch]

👨‍🎓论文作者

Qinglong Zhang, Yubin Yang

📦模型讲解

[背景介绍]

  Transformer相关背景见[Transformer]。Transformer在处理计算机视觉任务取得不错的效果,但是始终没能超过卷积神经网络的相关方法。例如,针对图像识别任务的ViT网络,见下图。

  另外,针对图像目标检测任务的DETR网络,见下图。

  目前大多数的针对CV的Transformer方法都是首先将输入图像拆分为补丁,补丁的处理方式与 NLP 应用程序中相同。然后使用几个自监督层进行全局的信息交流,提取特征进行分类。直到谷歌提出NesT的模型,才使得Transformer方法在CV领域成为SOTA。

[模型解读]

  在本节中,作者提出了一种高效的ResT模型,将Transformer用于图像识别任务,并拥有不错的速度。

[总体结构]

  ResT的总体结构与ResNet相似,都是有4个stage。每一个stage中有4部分组成,如下图所示。

[标准Transformer]

  标准的Transformer主要包含两个结构,MSA与FFN,每个结构都会有跨层链接,归一化采用LayerNorm的方法。假设输入Transformer的特征 x ∈ R n × d m x\in \mathbb{R}^{n \times d_m} xRn×dm n n n为空间维度, d m d_m dm为通道维度。对于标准的Transformer其输出为:
y = x ′ + F F N ( L N ( x ′ ) ) , a n d x ′ = x + M S A ( L N ( x ) ) y=x'+FFN(LN(x')),\quad and \quad x'=x+MSA(LN(x)) y=x+FFN(LN(x)),andx=x+MSA(LN(x))
  假设MSA中仅有一个head,那么 d k = d m d_k=d_m dk=dm,其中MSA 的计算成本为 O ( 2 d m n 2 + 4 d m 2 n ) O(2d_mn^2+4d_m^2n) O(2dmn2+4dm2n)。在FFN中,一般 d f = 4 d m d_f=4d_m df=4dm,FFN计算将花费 8 n d m 2 8nd_m^2 8ndm2

[EMSA]

  对于标准的Transformer,MSA有两个缺点:
(1)MSA尺度的计算是根据输入维度 d m d_m dm n n n的二次方计算的,导致训练和推理的开销巨大;
(2)MSA中的每一个head只负责输入的一个子集,这可能会损害网络的性能,特别是当每个子集中的通道维度 d k d_k dk太低时,使得Q和K的点积不能够构成一个信息匹配函数。
  为了解决这些问题,作者提出了一种高效的多头自注意力模块,如下图所示。

(1)与MSA类似,EMSA首先获得查询向量Q。
(2)为了减少存储占用,将二维输入 x ∈ R n × d m x\in \mathbb{R}^{n \times d_m} xRn×dm重构为三维 x ^ ∈ R d m × h × w \hat x \in \mathbb{R}^{d_m \times h \times w} x^Rdm×h×w,之后送入卷积神经网络进行s倍的下采样。这里作者设置 s = 8 / k s=8/k s=8/k。卷积的步长为s,卷积核大小为为s+1,padding为s/2。
(3)在空间维度缩小后 x ^ ∈ R d m × h / s × w / s \hat x \in \mathbb{R}^{d_m \times h/s \times w/s} x^Rdm×h/s×w/s再重构为二维向量 x ^ ∈ R n ′ × d m , n ′ = h / s × w / s \hat x \in \mathbb{R}^{n' \times d_m},n'=h/s \times w/s x^Rn×dm,n=h/s×w/s,通过 x ^ \hat x x^获得键向量K与值向量V。
(4)之后根据下式进行自注意运算。
E M S A ( Q , K , V ) = I N ( S o f t m a x ( C o n v ( Q K T d k ) ) ) V EMSA(\bm{Q},\bm{K},\bm{V})=IN(Softmax(Conv(\frac{\bm{Q}\bm{K}^T}{\sqrt{d_k}})))\bm{V} EMSA(Q,K,V)=IN(Softmax(Conv(dk QKT)))V
其中Conv()为标准的1×1卷积操作,IN()表示Instance Normalization。
(5)最后,将所有head特征串联生成最后的特征。
  相对标准的MSA,EMSE的计算复杂度为 O ( 2 d m n 2 s 2 + 2 d m 2 n ( 1 + 1 s 2 ) + d m n ( s + 1 ) 2 s 2 + k 2 n 2 s 2 ) O(\frac{2d_mn^2}{s^2}+2d_m^2n(1+\frac{1}{s^2})+d_mn\frac{(s+1)^2}{s^2}+\frac{k^2n^2}{s^2}) O(s22dmn2+2dm2n(1+s21)+dmns2(s+1)2+s2k2n2),假设s>1则会低于MSA。
此外,在 EMSA 之后添加 FFN 以进行特征转换和非线性。每个高效 Transformer 模块的输出为:
y = x ′ + F F N ( L N ( x ′ ) ) , a n d x ′ = x + E M S A ( L N ( x ) ) y=x'+FFN(LN(x')),\quad and\quad x'=x+EMSA(LN(x)) y=x+FFN(LN(x)),andx=x+EMSA(LN(x))

[Patch Embedding/切片嵌入]

  以ViT为例,网络将输入图片 x ∈ R 3 × h × w x\in \mathbb R^{3 \times h \times w} xR3×h×w裁剪为多个 s × s s \times s s×s的Patch,之后进行嵌入编码生成 x ∈ R n × c x \in \mathbb R^{n \times c} xRn×c n = h w / p 2 n=hw/p^2 n=hw/p2,c为编码的维度。在作者提出的ResT网络中,首先输入stem模块进行4×4的下采样,作者采用的是三个卷积层进行放缩,步长分别为2、1、2。前两层使用BatchNorm与ReLU激活函数。
  在stage2、3、4的Patch Embedding中,会对特征空间维度进行4倍下采样,通道维度变为2倍,通过步长为2、填充为1的标准3×3卷积来完成。

[Position Encoding/位置编码]

  以ViT为例,将一组可学习的参数 θ \theta θ作为位置编码。即编码后输出为: x ^ = x + θ \hat x =x+\theta x^=x+θ  由此导致输入网络的特征尺寸不能变化。
  本文作者假设参数 θ \theta θ与输入 x x x有关,即 θ = G L ( x ) \theta=GL(x) θ=GL(x) x ^ = x + G L ( x ) \hat x =x+GL(x) x^=x+GL(x)   θ \theta θ也可以通过更灵活的注意力机制获得。作者使用逐像素注意力(PA)模块来进行位置编码。PA 使用3×3卷积来获得像素级权重,然后通过sigmoid函数 σ ( ⋅ ) \sigma(\sdot ) σ()进行缩放。 PA模块的位置编码可以表示为: x ^ = P A ( x ) = x × σ ( D W C o n v ( x ) ) \hat x =PA(x)=x\times \sigma(DWConv(x)) x^=PA(x)=x×σ(DWConv(x))  将PA位置编码与Patch Embedding结合结构如下。

[Classification Head/分类头]

  分类头采用全局平均池化与全连接层。网络的结构见下表。Conv-k_c_s表示卷积层,卷积核大小为k,输出通道数为c,步长为s。MLP_c表示中间特征维度为4c,输入输出特征维度为c。MCSA_n_r表示EMSA,n为head数目,r为下采样倍数。

[结果分析]

[图像分类-ImageNet-1k]
[目标检测-COCO]
[实例分割-COCO]
[消融实验]
[模型类别]
[EMSA结构]
[位置编码]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值