FasterVIT-Fast Vision Transformers with Hierarchical Attention
序言
该篇文章由NVIDIA团队于2023年6月9日发表于arxiv,文章链接,代码链接。
本文主要提出了一组综合性能超强的VIT的encoder架构(CNN-VIT混用),其中提出了一种新颖的自注意力机制HAT(Hierarchical Attention)。由图1所示,本文在与近几年SOTA网络比较,不同速度阶段(图片吞吐量)均取得精度最优Or不同精度阶段均取得速度最优,即综合性能最优。例:FasterVIT-0相较Visformer和EfficientNetV2有相当的精度,但是图片吞吐量多将近50%;FasterVIT-2相较ConvNeXt-4有相当的精度,吞吐量前者为3100左右,后者为不到500,相差5倍之多;而FasterVIT-4&5均取得了精度的SOTA,且速度相较其他大模型仍有优势。从不同模型尺度上看,FasterVIT在不同程度上取得了优异的效果,是毫无疑问的SOTA架构,且其即插即用的性质对模型研究者非常友好,友友们,用起来~~
图1:图像吞吐量和ImageNet1K Top-1评价(吞吐连数据为A100 Batchsize128运行)
引言
论文该部分除了介绍Vit的发展历史,感兴趣的友友可以着重看一下。其次主要介绍了在FasterViT架构的背景下层次注意力(HAT)的概念,HAT是一种新型的窗口注意力机制,它实现了局部特征信息和全局特征信息的有效结合。它引入了Swin Transformer中的本地窗口注意力机制,并提出了Carrier Tokens(CTs)模块,该模块可以总结整个本地窗口的特征并促进局部信息与全局信息的交换。基于网络架构中CNN与VIT的混用和窗口注意力机制线性化的计算资源的优势,FasterVIT在速度上相交其他VIT架构有天然优势。
网络模块与计算资源和内存之间联系的讨论
1.在多层级的视觉模型中,初始的网络层通常具有较大的空间维度和较少的通道数量。这意味着在这些层中,处理的数据量较大,可能需要更多的内存进行存储和操作。因此,这些初始层可能会受到内存限制的影响,计算速度可能会受到内存操作的瓶颈影响,因而称为"memory-bound"(内存受限)。面对这种情形,本文使用相关的CNN Block代替VIT Block有效地缓解了内存受限的情况。此外,本文还提出使用密集型的卷积(残差块)替代深度可分离卷积和稀疏卷积的操作,因为密集型的卷积主要涉及计算资源而不涉及在内存中的大量数据传输。而稀疏卷积的操作虽然只对部分位置进行卷积操作,但是其multi-操作则可能引入额外的数据开销,在内存资源受限的情况可能存在效率较低的情况。
2.一些无法矩阵形式表示的操作,例如非线性激活函数、池化和批归一化等,也会受到内存限制,对于前面的两层网络应该尽量减少使用这些操作。
3.随着网络层数的增加,其拥有了更多的通道数量,此时需要引入更具表达力的操作如层归一化和注意力机制等,这些操作可以提升网络的表达能力,并且相对于计算量来说,对计算吞吐量的影响较小(可以通过并行计算来处理多位置或多通道的关系)。
FasterVIT网络架构
网络架构采用多尺度的形式,在早期阶段的高分辨率层上运行卷积层,模型后半部分依赖于新的分层注意力层来在整个特征图上进行空间推理。其中,网络前半部分和下采样块使用了稠密卷积核,此外尽可能地避免了Squeeze(压缩)-Excitation(激励)操作和前两阶段的层归一化操作。注释:SE操作主要应用于卷积神经网络的通道维度上,是通过学习动态调整通道之间的权重,以增强有用特征的表示能力。Squeeze操作,通过全局平均池化操作,将每个通道的特征图转换为一个单一的数值。这相当于对每个通道进行了压缩操作,将其转化为一个全局的通道描述。Excitation操作,通过一个全连接的神经网络(通常包括一个或多个全连接层和激活函数),将上一步骤得到的全局通道描述映射为一组通道权重。这些通道权重用于对原始特征图中的每个通道进行加权操作,以增强对有用信息的响应。
图2:FasterVIT网络架构
网络输入数据预处理部分Stem
通过两个步长为2的3*3的连续卷积实现对输入图像的Overlapping patches embedding如图3所示,这些卷积层的作用是将输入图像中的像素映射到一个D维的嵌入空间,然后通过批归一化(提高模型训练稳定性)和ReLU函数(保持非线性特征)的处理。这一系列的操作可以初步地提取图像的特征,并为后续处理步骤提供更具表现力的输入。
图3:Overlapping patches embedding
下采样块
采用一层2D层归一化和步长为2的3*3卷积层。
卷积层
一二阶段的卷积块是经典的残差结构。
x
^
=
G
E
L
U
(
B
N
(
C
O
N
3
×
3
(
x
)
)
)
\hat{x}=GELU\left(BN\left({CON}_{3\times3}\left(x\right)\right)\right)
x^=GELU(BN(CON3×3(x)))
x
=
B
N
(
C
O
N
3
×
3
(
x
^
)
)
+
x
x=BN\left({CON}_{3\times3}\left(\hat{x}\right)\right)+x
x=BN(CON3×3(x^))+x
HAT(Hierarchical Attention)
图4为HAT的具体架构,从图中可以看出其主要有两个分支,一个分支为进行Local windows操作的局部特征提取分支,另一个分支为Carrier Tokens(CTs)操作的全局特征提取分支,CTs可以对整个局部窗口进行全局化特征提取。首先在CTs分支应用第一次注意力块来总结和传递全局信息,然后通过Concat操作对局部窗口的特征和全局特征对应融合,起到一一对应的作用。这样可以保障每个局部窗口只能与自身的全局特征相结合,并在局部窗口级别上进行后续的推理和预测。然后通过在融合的标记上执行第二次注意力机制,进一步促进了融合特征的信息交换和提取。然后对融合特征进行分离,并交替使用全局注意力机制和局部注意力机制,模型可以进一步地进行分组,窗口具有更高级别的特征信息,以便更好地捕捉不同层次的信息。
图4:Hierarchical Attention
首先将输入的大小为H×W×d的图像x窗口化得到x1,d为feature map的数量,k为窗口大小。
x
1
^
=
S
p
l
i
t
k
∗
k
(
x
)
\widehat{x_1}={\rm Split}_{k*k}(x)
x1
=Splitk∗k(x)
CTs的式子如下所示,首先通过卷积操作对输入进行位置编码,然后通过池化操作对分别的局部窗口进行全局特征提取。
x
c
^
=
C
o
n
v
3
∗
3
(
x
)
\widehat{x_c}={\rm Conv}_{3*3}\left(x\right)
xc
=Conv3∗3(x)
x
c
t
^
=
A
v
g
p
o
o
l
(
x
c
^
)
\widehat{x_{ct}}=Avgpool(\widehat{x_c})
xct
=Avgpool(xc
)
第一次注意力机制
x
c
t
^
=
x
c
t
^
+
γ
1
∗
M
H
S
A
(
L
N
(
x
c
t
^
)
)
\widehat{x_{ct}}=\widehat{x_{ct}}+\gamma_1*MHSA(LN(\widehat{x_{ct}}))
xct
=xct
+γ1∗MHSA(LN(xct
))
x
c
t
^
=
x
c
t
^
+
γ
2
∗
M
L
P
(
L
N
(
x
c
t
^
)
)
\widehat{x_{ct}}=\widehat{x_{ct}}+\gamma_2*MLP(LN(\widehat{x_{ct}}))
xct
=xct
+γ2∗MLP(LN(xct
))
融合特征
x
w
^
=
C
o
n
c
a
t
(
x
1
^
,
x
c
t
^
)
\widehat{x_w}=Concat(\widehat{x_1},\widehat{x_{ct}})
xw
=Concat(x1
,xct
)
第二次注意力机制
x
w
^
=
x
w
^
+
γ
1
∗
M
H
S
A
(
L
N
(
x
w
^
)
)
\widehat{x_w}=\widehat{x_w}+\gamma_1*MHSA(LN(\widehat{x_w}))
xw
=xw
+γ1∗MHSA(LN(xw
))
x
w
^
=
x
w
^
+
γ
2
∗
M
L
P
(
L
N
(
x
w
^
)
)
\widehat{x_w}=\widehat{x_w}+\gamma_2*MLP(LN(\widehat{x_w}))
xw
=xw
+γ2∗MLP(LN(xw
))
最后融合特征被进一步拆分,成为下一个HAT的输入。
x
1
,
^
x
c
t
^
=
S
p
l
i
t
(
x
w
^
)
\widehat{x_1,}\widehat{x_{ct}}=Split(\widehat{x_w})
x1,
xct
=Split(xw
)
此外,在网络最后,为了促进长距离信息交互,作者使用全局信息传播机制,在该过程中,通过将来自不同位置或区域的信息进行整合和传播,模型可以更好地捕捉全局上下文和相关性。
x
=
U
p
s
a
m
p
l
e
(
x
c
t
^
)
+
M
e
r
g
e
(
x
1
^
)
x=Upsample(\widehat{x_{ct}})+Merge(\widehat{x_1})
x=Upsample(xct
)+Merge(x1
)
值得注意的是,HAT结构中分别添加了一个绝对位置偏置和相对位置偏置。
直接向CTs和局部窗口标记添加绝对位置偏置可以在一定程度上解决"token position invariant"的问题。当模型在执行自注意力机制时,原始的注意力计算是基于标记之间的相对位置来进行的,而没有直接考虑标记的绝对位置。这可能导致模型在处理具有空间结构的数据(例如图像)时,无法充分利用标记在空间维度上的位置信息。通过添加绝对位置偏置,我们可以将标记的绝对位置信息引入到注意力计算中。具体地,通过使用一个多层感知器(MLP),我们将每个标记的绝对2D位置映射到特征维度。这样,模型可以通过绝对位置偏置对标记的位置信息进行编码,并在注意力计算过程中考虑到这些位置信息。通过引入绝对位置偏置,模型能够在执行自注意力时更好地捕捉标记在空间维度上的位置相关性。
对数空间相对位置偏置使用一个两层的多层感知器(MLP)来计算,在计算相对位置偏置时,通常会使用对数空间(log space)进行计算,这是因为对数空间可以更好地处理不同距离范围内的相对位置。通过使用对数空间,可以更平衡地对待远距离和近距离之间的相对位置关系。在注意力计算中,通过将对数空间相对位置偏置与注意力权重相乘,可以增强或调整注意力分配的方式。这意味着相对位置偏置可以影响模型对于不同标记之间的关注程度。
论文后面还有一系列参数比较和消融实验,详细可看原文。这是一个即插即用的模块,大家可以用起来优化自己的模型啦~~