文章目录
1. Title
- Twins: Revisiting the Design of Spatial Attention in Vision Transformers
- https://github.com/Meituan-AutoML/Twins
2. Summary
本文整体设计思路与之前的Vision Transformer Backbone一致,如何获取多尺度特征(PVT)以及如何降低Self-Attention的计算复杂度(SwinTransformer)。
第一点PVT的方法就是一种比较合适的方法,各个Stage将Embedding进行合并,从而减少Embedding的数目,以实现不同Stage具有不同分辨率的金字塔结构。
第二点SwinTransformer通过在Local Window内部计算Self-Attention的方式降低了计算复杂度,同时提出了一种比较复杂的Shift-Window的方式去捕获各个窗口之间的依赖关系。
作者认为SwinTransformer的Shift-Window方法较为复杂,且现代深度学习框架支持性不够好,因此,本文提出了一种更为简单的方法去实现。
整体思路可以认为是PVT+SwinTransformer的结合:在局部窗口内部计算Self-Attention(SwinTransformer),同时对每个窗口内部的特征进行压缩,然后再使用一个全局Attention机制去捕获各个窗口的关系(PVT)。
个人觉得Twins这种思路更加简单,而且也比较高效,SwinTransformer的实现包含了作者大量精巧的设计(阅读过SwinTransformer的代码应该都会觉得其实现是十分巧妙且美妙的),实际上也是比较复杂的。大道至简,还是Twins的设计更符合认知一点,不过其性能还有待验证。
3. Problem Statement
相较于CNN来说,Transformer由于其能高效地捕获远距离依赖的特性,近期在计算机视觉领域也引领了一波潮流。Transformer主要是依靠Self-Attention去捕获各个token之间的关系,但是这种Global Self-Attention的计算复杂度太高 O ( N 2 ) O(N^2) O(N2),不利于在token数目较多的密集检测任务(分割、检测)中使用。
基于以上考虑,目前主流有两种应对方法:
- 一种是以SwinTransformer为代表的Locally-Grouped Self-Attention。其在不重叠的窗口内计算Self-Attention,当窗口大小固定时,整体的计算复杂度将下降至 O ( N ) O(N) O(N),然后再通过其他方法例如SwinTransformer中的Shift-Window的方法去实现窗口间的互动。但这种方法的缺点在于窗口的大小会不一致,不利于现代深度学习框架的优化和加速。
- 一种是以PVT为代表的Sub-Sampled Version Self-Attention。其在计算Self-Attention前,会先对QKT Token进行下采样,虽然计算复杂度还是 O ( N 2 ) O(N^2) O(N2),但是在实际过程中也已经可以接受。
总结下来,目前的VisionTransformer的关键设计在于Spatial Attention的设计。
因此,本文将重点在Spatial Attention的设计上,期望提出一个高效同时简单的Spatial Attention方法。
4. Method(s)
作者有两个发现:
- PVT中的Global Sub-Sample Attention是十分高效的,当配合上合适的Positional Encodings(Conditional Positional Encoding)时,其能取得媲美甚至超过目前SOTA的Transformer结构。
- 更进一步,基于Separable Depthwise Convolution的思想,本文提出了一个Spatially Separable Self-Attention(SSSA)。该模块仅包含矩阵乘法,在现代深度学习框架下能够得到优化和加速。
4.1 Twins-PCPVT
PVT通过逐步融合各个Patch的方式,形成了一种多尺度的结构,使得其更适合用于密集预测任务例如目标检测或者是语义分割,其继承了ViT和DeiT的Learnable Positional Encoding的设计,所有的Layer均直接使用Global Attention机制,并通过Spatial Reduction的方式去降低计算复杂度。
作者通过实验发现,PVT与SwinTransformer的性能差异主要来自于PVT没有采用一个合适的Positional Encoding方式,通过采用Conditional Positional Encoding(CPE)去替换PVT中的PE,PVT即可获得与当前最好的SwinTransformer相近的性能。关于CPE的具体介绍,可以参见我的另一篇博客:Conditional Positional Encodings for Vision Transformers。
4.2 Twins-SVT
通过提出的Spatially Separable Self-Attention(SSSA)去缓解Self-Attention的计算复杂度过高的问题。SSSA由两个部分组成:Locally-Grouped Self-Attention(LSA)和Global Sub-Sampled Attention(GSA)。
4.2.1 Locally-Grouped Self-Attention(LSA)
首先将2D feature map划分为多个Sub-Windows,并仅在Window内部进行Self-Attention计算,计算量会大大减少,由 ( H 2 W 2 d ) \left(H^{2} W^{2}d\right) (H2W2d)下降至 O ( k 1 k 2 H W d ) \mathcal{O}\left(k_{1} k_{2} H W d\right) O(k1k2HWd),其中 k 1 = H m , k 2 = W n k_{1}=\frac{H}{m}, k_{2}=\frac{W}{n} k1=mH,k2=nW,当 k 1 , k 2 k_1,k_2 k1,k2固定时,计算复杂度将仅与 H W HW HW呈线性关系。
4.2.2 Global Sub-Sampled Attention(GSA)
LSA缺乏各个Window之间的信息交互,比较简单的一个方法是,在LSA后面再接一个Global Self-Attention Layer,这种方法在实验中被证明也是有效的,但是其计算复杂度会较高: O ( H 2 W 2 d ) \mathcal{O}\left(H^{2} W^{2} d\right) O(H2W2d)。
另一个思路是,将每个Window提取一个维度较低的特征作为各个window的表征,然后基于这个表征再去与各个window进行交互,相当于Self-Attention中的Key的作用,这样一来,计算复杂度会下降至:
O
(
m
n
H
W
d
)
=
O
(
H
2
W
2
d
k
1
k
2
)
\mathcal{O}(m n H W d)=\mathcal{O}\left(\frac{H^{2} W^{2} d}{k_{1} k_{2}}\right)
O(mnHWd)=O(k1k2H2W2d)。
这种方法实际上相当于对feature map进行下采样,因此,被命名为Global Sub-Sampled Attention。
综合使用LSA和GSA,可以取得类似于Separable Convolution(Depth-wise+Point-wise)的效果,整体的计算复杂度为:
O
(
H
2
W
2
d
k
1
k
2
+
k
1
k
2
H
W
d
)
\mathcal{O}\left(\frac{H^{2} W^{2} d}{k_{1} k_{2}}+k_{1} k_{2} H W d\right)
O(k1k2H2W2d+k1k2HWd)。同时有:
H
2
W
2
d
k
1
k
2
+
k
1
k
2
H
W
d
≥
2
H
W
d
H
W
\frac{H^{2} W^{2} d}{k_{1} k_{2}}+k_{1} k_{2} H W d \geq 2 H W d \sqrt{H W}
k1k2H2W2d+k1k2HWd≥2HWdHW,当且仅当
k
1
⋅
k
2
=
H
W
k_{1} \cdot k_{2}=\sqrt{H W}
k1⋅k2=HW。
考虑到分类任务中,
H
=
W
=
224
H=W=224
H=W=224是比较常规的设置,同时,不是一般性使用方形框,则有
k
1
=
k
2
k_1=k_2
k1=k2,第一个stage的feature map大小为56,可得
k
1
=
k
2
=
56
=
7
k_1=k_2=\sqrt{56}=7
k1=k2=56=7。
当然可以针对各个Stage去设定其窗口大小,不过为了简单性,所有的
k
k
k均设置为7。
整个Transformer Block可以被表示为:
z
^
i
j
l
=
LSA
(
LayerNorm
(
z
i
j
l
−
1
)
)
+
z
i
j
l
−
1
z
i
j
l
=
FFN
(
LayerNorm
(
z
^
i
j
l
)
)
+
z
^
i
j
l
z
^
l
+
1
=
GSA
(
LayerNorm
(
z
l
)
)
+
z
l
z
l
+
1
=
FFN
(
LayerNorm
(
z
^
l
+
1
)
)
+
z
^
l
+
1
,
i
∈
{
1
,
2
,
…
,
m
}
,
j
∈
{
1
,
2
,
…
,
n
}
\begin{array}{l} \hat{\mathbf{z}}_{i j}^{l}=\text { LSA }\left(\text { LayerNorm }\left(\mathbf{z}_{i j}^{l-1}\right)\right)+\mathbf{z}_{i j}^{l-1} \\ \mathbf{z}_{i j}^{l}=\text { FFN }\left(\text { LayerNorm }\left(\hat{\mathbf{z}}_{i j}^{l}\right)\right)+\hat{\mathbf{z}}_{i j}^{l} \\ \hat{\mathbf{z}}^{l+1}=\text { GSA }\left(\text { LayerNorm }\left(\mathbf{z}^{l}\right)\right)+\mathbf{z}^{l} \\ \mathbf{z}^{l+1}=\text { FFN }\left(\text { LayerNorm }\left(\hat{\mathbf{z}}^{l+1}\right)\right)+\hat{\mathbf{z}}^{l+1}, \\ i \in\{1,2, \ldots, m\}, j \in\{1,2, \ldots, n\} \end{array}
z^ijl= LSA ( LayerNorm (zijl−1))+zijl−1zijl= FFN ( LayerNorm (z^ijl))+z^ijlz^l+1= GSA ( LayerNorm (zl))+zlzl+1= FFN ( LayerNorm (z^l+1))+z^l+1,i∈{1,2,…,m},j∈{1,2,…,n}同时在每个Stage的第一个Block中会引入CPVT中的的PEG对位置信息进行编码。
4.3 Model Variants
5. Evaluation
5.1 Experiments
通过以上实验结果可以看出,Twins系列在各个任务上均取得了与SwinTransformer相当甚至是超过的水平,不过相比较而言,除了Small,Twins的模型参数比SwinTransformer系列都稍微大一点,而且运行速度似乎也没有明显优势。
5.2 Ablation Studies
6. Conclusion
本文提出了两种Vision Transformer Backbone,同时适用于图片级的分类任务或是其他密集预测任务,并且在分类、分割、检测等多个任务上,均取得了新的SOTA。