AdaViT:用于高效图像识别的自适应视觉变换器
paper题目:AdaViT: Adaptive Vision Transformers for Efficient Image Recognition
paper是复旦大学发表在CVPR 2022的工作
paper地址:链接
Abstract
建立在自注意力机制之上的视觉变换器最近在各种任务上都表现出了显著的性能。虽然取得了卓越的性能,但它们仍然需要相对密集的计算成本,随着patch、自注意力头和变换器块数量的增加而急剧扩大。在本文中,我们认为,由于图像之间的巨大变化,它们对patch之间长距离依赖关系的建模需求是不同的。为此,我们引入了AdaViT,这是一个自适应计算框架,它可以学习推导出关于在每个输入的基础上在整个骨干网中使用哪些patch、自注意力头和变换器块的使用策略,旨在提高视觉变换器的推理效率,使图像识别的准确性下降到最小。以端到端的方式与变换器主干网共同优化,一个轻量级的决策网络被附加到主干网上,以产生即时的决策。在ImageNet上进行的大量实验表明,与最先进的视觉变换器相比,我们的方法在效率上提高了2倍以上,而准确率只下降了0.8%,在不同的计算预算条件下实现了良好的效率/准确率折衷。我们进一步对学习到的使用决策进行定量和定性分析,并对视觉变换器中的冗余提供更多的见解。Code:链接。
1. Introduction
变换器[40]是各种自然语言处理任务的主流架构,自从视觉变换器(ViT)[7]的成功以来,在计算机视觉界引起了越来越多的研究兴趣。构建在自注意力机制之上的变换器能够有效地捕捉输入图像中像素/patch之间的长距离依赖关系,这可以说是它们在视觉任务中胜过标准CNN的主要原因之一,这些任务涵盖了从图像分类[4, 12, 20, 22, 38, 49, 56]到目标检测[3,5,43,44]、动作识别[9,23,58]等等。
最近关于视觉变换器的研究[4, 7, 38, 56]通常采用来自NLP的变换器[40]架构,并做了最小的变动。以类似于标记/词的一连串切片图像patch作为输入,变换器的主干由具有两个子层的堆叠积木组成,即一个自注意力层和一个前馈网络。为了确保模型能够共同关注来自不同表征子空间的信息,在每个模块中使用了多头注意力,而不是单一的注意力函数[40]。虽然这些基于自注意力的视觉转换器在ImageNet[6]等众多基准上的表现优于CNN,但竞争性的性能并不是免费的–带有多头的堆叠注意块的计算成本很大,而且随着patch数量的增加而进一步呈二次增长。
但是,为了正确地对图像进行分类,是否需要在整个网络中关注所有的patch?我们是否需要所有具有多个头部的自注意力块来寻找所有不同图像的关注点并建立潜在的依赖关系?毕竟,图像中存在很大的变化,如物体形状、物体大小、遮挡和背景复杂性。直观地说,对于包含杂乱背景或遮挡物体的复杂图像,需要更多的斑块和自注意力块,这需要足够的背景信息和对整个图像的理解,以便推断出它们的真实类别(如图1中的理发店),而(white stork)只需要少量的信息patch和注意头/块就足以对简单图像进行正确分类。
图1。我们的方法的概念概述。利用视觉变换器的冗余,AdaViT学会产生特定实例的使用策略,在这些策略上,patch、自注意力头和变换器块在整个网络中保持/激活,以实现高效的图像识别。较少的计算资源分配给简单样本(上),而更多的计算资源用于难样本(下),减少了总体计算成本,分类精度下降最小。两个图中都激活了绿色patch。
考虑到这一点,我们寻求开发一个自适应计算框架,学习在每个输入基础上使用哪些patch和激活哪些自注意力头/块。这样,对于简单的样本,可以舍弃冗余的输入补丁和骨干网层,对于复杂的样本,只使用全补丁的全模型,从而节约视觉变压器的计算成本。这是与近年来专注于静态网络架构设计的高效视觉变压器方法的正交互补方向[4,11,22,56]。
为此,我们引入了自适应视觉转换器(Adaptive Vision Transformer, AdaViT),这是一个端到端框架,它自适应地确定以输入图像为条件的视觉转换器的patch、头部和层的使用,以实现高效的图像分类。我们的框架学习派生实例特定推理策略:1)保留哪些patch;2)哪个自注意力头被激活;3)每幅图像要跳过哪些变换器块,以在降低分类精度的前提下提高推理效率。特别地,我们在骨干网的每个变换器块中插入一个轻量级的多头子网络(即决策网络),它学习预测整个网络中patch嵌入、自注意力头和块的使用的二进制决策。由于二元决策是不可微的,我们在训练时使用Gumbel-Softmax[26]使整个框架端到端可训练。决策网络与变换器骨干网联合优化,使用损耗度量所产生的使用策略的计算成本,以及正常的交叉熵损耗,这激励网络在保持分类精度的同时产生减少计算成本的策略。总体目标计算成本可以通过超参数 γ ∈ ( 0 , 1 ] \gamma \in(0,1] γ∈(0,1]来控制,对应于训练时以所有patch为输入的完整模型计算成本的百分比,使框架具有灵活性,能够适应不同计算预算的需要。
我们在ImageNet[6]上进行了大量的实验,验证了AdaViT的有效性,结果表明,我们的方法能够将视觉变换器的推理效率提高2×以上,而分类精度仅下降0.8%,与其他标准视觉变换器和cnn相比,实现了效率和精度之间的良好权衡。此外,我们对学习到的使用策略进行了定量和定性的分析,对视觉变换器的冗余提供了更多的直观和洞察。我们进一步展示了可视化,并演示了AdaViT学会了对复杂场景中相对困难的样本使用更多的计算,而对简单的以目标为中心的样本使用较少的计算。
2. Related Work
视觉变换器。受其在NLP任务中的巨大成功的启发,许多最近的研究探索了将Transformer[40]架构适应于多个计算机视觉任务[7-9,14,22,27,31,32,42,46,48,54,59]。继ViT[7]之后,人们提出了多种视觉变换器变体,以提高识别性能以及训练和推理效率。DeiT[38]结合了蒸馏策略来提高视觉转换器的训练效率,在JFT[35]这样的大规模数据集上不需要预训练就能超越标准cnn。其他方法如T2T-ViT [56], Swin Transformer [22], PVT[44]和CrossViT[4]寻求改善视觉变换器的网络架构。通过使用卷积层[20,53]、分层网络结构[22,23,44]、多尺度特征聚合[4,9]等方法,将二维cnn的优势引入变换器。视觉变换器在获得优越性能的同时,计算成本仍然很高,而且随着patch、自注意力头和变压器块数量的增加,计算成本会迅速上升。
高效的网络。通过设计有效的轻量级网络架构来提高cnn执行视觉任务的效率已经进行了广泛的研究[15,16,25,34,37,57]。为了匹配标准cnn的推理效率,最近的工作还探索开发高效的视觉转换器架构。T2TViT[56]提出使用深窄结构和token到token模块,实现了比ViT[7]更好的精度和更少的计算成本。LeViT[11]和Swin Transformer[22]开发了具有下采样的多级网络架构,获得了更好的推理效率。然而,这些方法对所有输入样本使用固定的网络体系结构,而不考虑patch中的冗余,对简单样本使用网络体系结构。我们的工作与这个方向正交,并专注于学习输入特定的策略,自适应地分配计算资源以节省计算量,同时降低精度。
自适应计算。自适应计算方法利用网络输入的巨大变化和网络架构的冗余,以提高实例特定推理策略的效率。**特别是,现有的cnn方法探索了改变输入样本[28,29,39,50,51,55],跳过网络层[10,18,41,45,52]和通道[1,21],早期使用多分类器结构[2,17,19]等。**最近也进行了一些尝试,以利用patch中的冗余的自适应推理策略来加速视觉转换器,例如,根据输入图像的条件生成关于patch大小为[47]和使用哪些patch[30,33]的策略。相比之下,我们利用了视觉变换器注意机制中的冗余,并提出了通过自适应选择自注意力头、变换器块和patch嵌入对输入样本保持/丢弃条件来提高效率。
3. Approach
我们提出了一种端到端自适应计算框架AdaViT来降低视觉变换器的计算成本。给定一个输入图像,AdaViT学习自适应派生策略,以输入图像为条件,在变换器主干中使用或激活patch、自注意力头和变换器块,鼓励使用更少的计算,同时保持分类精度。我们的方法的概述如图2所示。
图2。概述我们的方法。我们在视觉转换器主干的每个块之前插入一个轻量级的决策网络。给定一个输入图像,决策网络产生使用策略,patch、自注意力头和变换器块在整个主干上保持/激活。这些特定于实例的使用策略被鼓励以最小的精度下降来减少视觉转换器的总体计算成本。更多细节见文本。
3.1. Preliminaries
用于图像分类的视觉变换器[7, 38, 56]将图像的一连串切片作为输入,用堆叠的多头自注意力层和前馈网络来模拟它们的长程依赖关系。从形式上看,对于一个输入图像
I
\mathcal{I}
I,它首先被分割成一串固定大小的二维
X
=
[
x
1
,
x
2
,
…
,
x
N
]
\mathbf{X}=\left[\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_N\right]
X=[x1,x2,…,xN],其中
N
N
N是patch的数量(例如
N
=
14
×
14
N=14 \times 14
N=14×14)。然后,这些原始patch被映射到D维patch嵌入
Z
=
[
z
1
,
z
2
,
…
,
z
N
]
\mathbf{Z}=\left[\mathbf{z}_1, \mathbf{z}_2, \ldots, \mathbf{z}_N\right]
Z=[z1,z2,…,zN]的线性层。一个可学习的嵌入
z
c
l
s
\mathbf{z}_{c l s}
zcls被称为类标记,它被附加到patch嵌入序列中,作为图像的表示。位置嵌入
E
pos
\mathbf{E}_{\text {pos }}
Epos 也被选择性地添加到patch嵌入中,以增加它们的位置信息。总而言之,第一个转化器块的输入是
Z
=
[
z
c
l
s
;
z
1
;
z
2
;
…
;
z
N
]
+
E
pos
(
1
)
\mathbf{Z}=\left[\mathbf{z}_{c l s} ; \mathbf{z}_1 ; \mathbf{z}_2 ; \ldots ; \mathbf{z}_N\right]+\mathbf{E}_{\text {pos }}(1)
Z=[zcls;z1;z2;…;zN]+Epos (1)
其中
z
∈
R
D
\mathbf{z} \in \mathbb{R}^D
z∈RD 和
E
pos
∈
R
(
N
+
1
)
×
D
\mathbf{E}_{\text {pos }} \in \mathbb{R}^{(N+1) \times D}
Epos ∈R(N+1)×D。
与NLP中的变换器[40]类似,视觉变换器的主干网络由
L
L
L块组成,每块由多头自注意力层(MSA)和前馈网络(FFN)组成。具体来说,单头注意力的计算方法如下
Attn
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
(
2
)
\operatorname{Attn}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V(2)
Attn(Q,K,V)=softmax(dkQKT)V(2)
其中,
Q
,
K
,
V
Q, K, V
Q,K,V分别是广义上的查询、键和值矩阵,
d
k
d_k
dk是一个缩放系数。对于视觉变换器来说,
Q
,
K
,
V
Q, K, V
Q,K,V是由相同的输入,即patch嵌入投影而来。为了更有效地关注不同的表征子空间,多头自注意力将几个单头自注意力的输出串联起来,并用另一个参数矩阵投射:
head
i
,
l
=
Attn
(
Z
l
W
i
,
l
Q
,
Z
l
W
i
,
l
K
,
Z
l
W
i
,
l
V
)
(
3
)
MSA
(
Z
l
)
=
Concat
(
head
1
,
l
,
…
,
head
H
,
l
)
W
l
O
(
4
)
\begin{aligned} \operatorname{head}_{i, l} &=\operatorname{Attn}\left(\mathbf{Z}_l \mathbf{W}_{i, l}^Q, \mathbf{Z}_l \mathbf{W}_{i, l}^K, \mathbf{Z}_l \mathbf{W}_{i, l}^V\right)(3) \\ \operatorname{MSA}\left(\mathbf{Z}_l\right) &=\operatorname{Concat}\left(\operatorname{head}_{1, l}, \ldots, \operatorname{head}_{H, l}\right) \mathbf{W}_l^O (4) \end{aligned}
headi,lMSA(Zl)=Attn(ZlWi,lQ,ZlWi,lK,ZlWi,lV)(3)=Concat(head1,l,…,headH,l)WlO(4)
其中
W
i
,
l
Q
,
W
i
,
l
K
,
W
i
,
l
V
,
W
l
O
\mathbf{W}_{i, l}^Q, \mathbf{W}_{i, l}^K, \mathbf{W}_{i, l}^V, \mathbf{W}_l^O
Wi,lQ,Wi,lK,Wi,lV,WlO是第
l
l
l个变换器块的第
i
i
i个注意力头中的参数矩阵,
Z
l
\mathbf{Z}_l
Zl表示第
l
l
l个块的输入。MSA的输出随后被送入两层MLP的FFN,并产生变换器块
Z
l
+
1
\mathbf{Z}_{l+1}
Zl+1的输出。残差连接也被应用于MSA和FFN,如下所示。
Z
l
′
=
MSA
(
Z
l
)
+
Z
l
,
Z
l
+
1
=
FFN
(
Z
l
′
)
+
Z
l
′
(
5
)
\mathbf{Z}_l^{\prime}=\operatorname{MSA}\left(\mathbf{Z}_l\right)+\mathbf{Z}_l, \quad \mathbf{Z}_{l+1}=\operatorname{FFN}\left(\mathbf{Z}_l^{\prime}\right)+\mathbf{Z}_l^{\prime}(5)
Zl′=MSA(Zl)+Zl,Zl+1=FFN(Zl′)+Zl′(5)
最终的预测是由一个线性层产生的,以最后一个转换块
(
Z
L
0
)
\left(\mathbf{Z}_L^0\right)
(ZL0)的类标记作为输入。
3.2. Adaptive Vision Transformer
虽然大型视觉变换器模型已经取得了卓越的图像分类性能,但是当我们为了获得更高的精度而增加patch、注意力头和变换器块的数量时,计算成本会迅速增长。此外,计算成本高的 "一刀切 "网络对于许多简单的样本来说往往是矫枉过正的。为了解决这个问题,AdaViT学会了自适应地选择:1)使用哪些patch嵌入;2)激活MSA中的哪些自注意力头;以及3)在每个输入的基础上跳过哪些转化器块,以提高视觉转化器的推理效率。我们通过在每个转化器块之前插入一个轻量级的决策网络来实现这一目标,并对其进行训练以产生该块的三套使用策略。
Decision Network.
决策网络。第
l
l
l个区块的决策网络由三个线性层组成,参数
W
l
=
\mathbf{W}_l=
Wl=
{
W
l
p
,
W
l
h
,
W
l
b
}
\left\{\mathbf{W}_l^p, \mathbf{W}_l^h, \mathbf{W}_l^b\right\}
{Wlp,Wlh,Wlb},分别用于生成patch选择、注意力头选择和变换器块选择的计算使用策略。形式上,给定第
l
l
l个块的输入
Z
l
\mathbf{Z}_l
Zl,该块的使用策略矩阵计算如下。
(
m
l
p
,
m
l
h
,
m
l
b
)
=
(
W
l
p
,
W
l
h
,
W
l
b
)
Z
l
s.t.
m
l
p
∈
R
N
,
m
l
h
∈
R
H
,
m
l
b
∈
R
(
6
)
\begin{aligned} \left(\mathbf{m}_l^p, \mathbf{m}_l^h, \mathbf{m}_l^b\right)=\left(\mathbf{W}_l^p, \mathbf{W}_l^h, \mathbf{W}_l^b\right) \mathbf{Z}_l \\ \text { s.t. } \mathbf{m}_l^p \in \mathbb{R}^N, \mathbf{m}_l^h \in \mathbb{R}^H, \mathbf{m}_l^b \in \mathbb{R} \end{aligned}(6)
(mlp,mlh,mlb)=(Wlp,Wlh,Wlb)Zl s.t. mlp∈RN,mlh∈RH,mlb∈R(6)
其中,
N
N
N和
H
H
H表示一个变换器块中的patch和自注意力头的数量,
l
∈
[
1
,
L
]
l \in[1, L]
l∈[1,L]。
m
l
p
,
m
l
h
\mathbf{m}_l^p, \mathbf{m}_l^h
mlp,mlh和
m
l
b
\mathbf{m}_l^b
mlb的每个条目进一步传递给一个sigmoid函数,分别表示保留相应patch、注意力头和转换块的概率。第
l
l
l个决策网络共享前
l
−
1
l-1
l−1个变换器块的输出,使该框架比使用独立的决策网络更有效率。
由于决策是二进制的,在推理过程中,可以通过简单地对条目应用阈值来选择保留/丢弃的行动。然而,推导出不同样本的最佳阈值是具有挑战性的。为此,我们定义了随机变量 M l p , M l h , M l b \mathbf{M}_l^p, \mathbf{M}_l^h, \mathbf{M}_l^b Mlp,Mlh,Mlb,通过从 m l p , m l h \mathbf{m}_l^p, \mathbf{m}_l^h mlp,mlh和 m l b \mathbf{m}_l^b mlb中取样来进行决策。例如,当 M l , j p = 1 \mathbf{M}_{l, j}^p=1 Ml,jp=1时,第 l l l个区块中的第 j j j个patch嵌入被保留,而当 M l , j p = 0 \mathbf{M}_{l, j}^p=0 Ml,jp=0时被放弃。我们使用Gumbel-Softmax技巧[26]放松采样过程,使其在训练期间具有可区分性(参见第。3.3.)
patch选择。对于每个变换器块的输入,我们的目标是只保留最有信息的补丁嵌入,并丢弃其余的以加速推断。更正式地,对于第
l
l
l个块,如果
M
i
p
\mathbf{M}_i^p
Mip中的相应条目等于0:
Z
l
=
[
z
l
,
c
l
s
;
M
l
,
1
p
z
1
;
…
;
M
l
,
N
p
z
N
]
(
7
)
\mathbf{Z}_l=\left[\mathbf{z}_{l, c l s} ; \mathbf{M}_{l, 1}^p \mathbf{z}_1 ; \ldots ; \mathbf{M}_{l, N}^p \mathbf{z}_N\right](7)
Zl=[zl,cls;Ml,1pz1;…;Ml,NpzN](7)
类标记
z
l
,
c
l
s
\mathbf{z}_{l, c l s}
zl,cls总是被保留,因为它被用作整个图像的表示。
头选择。多头自注意力使模型能够共同关注表示的不同子空间[40],并且在大多数(如果不是全部)视觉变换变体[4,7,22,38,56]中采用。这样的多头部设计对于建模图像中潜在的长距离依赖关系至关重要,尤其是那些具有复杂场景和杂乱背景的图像,但较少的注意力可能足以在简单的图像中寻找要注意的地方。考虑到这一点,我们探索了如何根据输入图像自适应地放下注意力,以便更快地进行推理。与patch选择类似,激活或停用某个注意头的决定由
M
l
h
\mathbf{M}_l^h
Mlh中的相应条目决定。注意头的“停用”可以以不同的方式进行实例化。在我们的框架中,我们探索了两种头部选择方法,即部分失活和完全失活。对于部分失活,注意softmax输出,如等式所示。2替换为预定义的值,如
(
N
+
1
,
N
+
1
)
(N+1, N+1)
(N+1,N+1)单位矩阵
1
\mathbb{1}
1,这样可以节省计算注意力图的成本。然后,第
l
l
l个区块第
i
i
i个头部的注意力计算为:
Attn
(
Q
,
K
,
V
)
l
,
i
=
{
softmax
(
Q
K
T
d
k
)
⋅
V
if
M
l
,
i
h
=
1
1
⋅
V
if
M
l
,
i
h
=
0
\operatorname{Attn}(Q, K, V)_{l, i}= \begin{cases}\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) \cdot V & \text { if } \mathbf{M}_{l, i}^h=1 \\ \mathbb{1} \cdot V & \text { if } \mathbf{M}_{l, i}^h=0\end{cases}
Attn(Q,K,V)l,i={softmax(dkQKT)⋅V1⋅V if Ml,ih=1 if Ml,ih=0
为了完全失活,将整个头部从多头自注意力层中移除,MSA输出的嵌入尺寸相应减小:
MSA
(
Z
l
)
l
,
i
=
concat
(
[
head
l
,
i
:
1
→
H
if
M
l
,
i
h
=
1
]
)
W
l
O
′
\operatorname{MSA}\left(\mathbf{Z}_l\right)_{l, i}=\operatorname{concat}\left(\left[\operatorname{head}_{l, i: 1 \rightarrow H} \text { if } \mathbf{M}_{l, i}^h=1\right]\right) \mathbf{W}_l^{O^{\prime}}
MSA(Zl)l,i=concat([headl,i:1→H if Ml,ih=1])WlO′
实际上,当相同比例的头被停用时,与部分停用相比,完全停用可以节省更多的计算,但随着嵌入大小的动态操作,可能会产生更多的分类错误。
块选择。除了patch选择和头选择之外,由于整个网络中的残差连接,当变换器块是冗余的时,也可以有利地将其完全跳过。为了增加层跳过的灵活性,我们将块使用策略矩阵
m
l
b
\mathbf{m}_l^b
mlb的维度从1增加到2,使得每个变换器块中的两个子层(MSA和FFN)能够被单独控制。于是,等式5变成:
Z
l
′
=
M
l
,
0
b
⋅
MSA
(
Z
l
)
+
Z
l
Z
l
+
1
=
M
l
,
1
b
⋅
FEN
(
Z
l
′
)
+
Z
l
′
(
8
)
\begin{aligned} \mathbf{Z}_l^{\prime} &=\mathbf{M}_{l, 0}^b \cdot \operatorname{MSA}\left(\mathbf{Z}_l\right)+\mathbf{Z}_l \\ \mathbf{Z}_{l+1} &=\mathbf{M}_{l, 1}^b \cdot \operatorname{FEN}\left(\mathbf{Z}_l^{\prime}\right)+\mathbf{Z}_l^{\prime} \end{aligned}(8)
Zl′Zl+1=Ml,0b⋅MSA(Zl)+Zl=Ml,1b⋅FEN(Zl′)+Zl′(8)
总之,给定每个变换器块的输入,决策网络产生该块的使用策略,然后通过应用了决策的块转发输入。最后,获得来自最后一层的分类预测和所有块
M
=
{
M
l
p
,
M
l
h
,
M
l
b
\mathbf{M}=\left\{\mathbf{M}_l^p, \mathbf{M}_l^h, \mathbf{M}_l^b\right.
M={Mlp,Mlh,Mlb, for
l
:
1
→
L
}
\left.l: 1 \rightarrow L\right\}
l:1→L}的判决。
3.3. Objective Function
由于我们的目标是在精度下降最小的情况下降低视觉变压器的整体计算成本,因此AdaViT的目标函数旨在激励正确的分类,同时减少计算量。特别地,使用损失和交叉熵损失被用于联合优化框架。给定具有标签
y
\mathbf{y}
y的输入图像
I
I
I,最终预测由具有参数
θ
\boldsymbol{\theta}
θ的变换器
F
\mathbf{F}
F产生,并且交叉熵损失计算如下:
L
c
e
=
−
y
log
(
F
(
I
;
θ
)
)
(
9
)
L_{c e}=-\mathbf{y} \log (\mathbf{F}(I ; \boldsymbol{\theta}))(9)
Lce=−ylog(F(I;θ))(9)
虽然可以通过在推断期间应用阈值来容易地获得关于是否保留/丢弃patch/头/块的二元决策,但是确定最佳阈值是具有挑战性的。此外,这种操作在训练期间是不可微的,因此使得决策网络的优化具有挑战性。一种常见的解决方案是求助于强化学习,并使用策略梯度方法优化网络[36],但由于随离散变量的维度而变化的较大方差,收敛可能会很慢[26,36]。为此,我们使用Gumbel-Softmax技巧[26]来放松采样过程并使其可微分。形式上,
m
\mathbf{m}
m的第
i
i
i个分量处的判定是这样推导出来的:
M
i
,
k
=
exp
(
log
(
m
i
,
k
+
G
i
,
k
)
/
τ
)
∑
j
=
1
K
exp
(
log
(
m
i
,
j
+
G
i
,
j
)
/
τ
)
for
k
=
1
,
2
,
…
,
K
(
10
)
\begin{array}{r} \mathbf{M}_{i, k}=\frac{\exp \left(\log \left(\mathbf{m}_{i, k}+G_{i, k}\right) / \tau\right)}{\sum_{j=1}^K \exp \left(\log \left(\mathbf{m}_{i, j}+G_{i, j}\right) / \tau\right)} \\ \text { for } k=1,2, \ldots, K \end{array}(10)
Mi,k=∑j=1Kexp(log(mi,j+Gi,j)/τ)exp(log(mi,k+Gi,k)/τ) for k=1,2,…,K(10)
其中
K
K
K是类别的总数(在我们的例子中,对于二进制决策,
K
=
2
K=2
K=2),
G
i
=
−
log
(
−
log
(
U
i
)
)
G_i=-\log \left(-\log \left(U_i\right)\right)
Gi=−log(−log(Ui))是Gumbel分布,其中
U
i
U_i
Ui是从Uniform
(
0
,
1
)
(0,1)
(0,1)中采样的,一个i.i.d均匀分布。温度
τ
\tau
τ用于控制
M
i
\mathbf{M}_i
Mi的平滑度。
为了鼓励减少总体计算成本,我们设计了如下的使用损失:
L
usage
=
(
1
D
p
∑
d
=
1
D
p
M
d
p
−
γ
p
)
2
+
(
1
D
h
∑
d
=
1
D
h
M
d
h
−
γ
h
)
2
+
(
1
D
b
∑
d
=
1
D
b
M
d
b
−
γ
b
)
2
(
11
)
\begin{aligned} L_{\text {usage }} &=\left(\frac{1}{D_p} \sum_{d=1}^{D_p} \mathbf{M}_d^p-\gamma_p\right)^2+\left(\frac{1}{D_h} \sum_{d=1}^{D_h} \mathbf{M}_d^h-\gamma_h\right)^2 \\ &+\left(\frac{1}{D_b} \sum_{d=1}^{D_b} \mathbf{M}_d^b-\gamma_b\right)^2 \end{aligned}(11)
Lusage =⎝⎛Dp1d=1∑DpMdp−γp⎠⎞2+(Dh1d=1∑DhMdh−γh)2+(Db1d=1∑DbMdb−γb)2(11)
其中
D
p
=
L
×
N
,
D
h
=
L
×
H
,
D
b
=
L
×
2
D_p=L \times N, D_h=L \times H, D_b=L \times 2
Dp=L×N,Dh=L×H,Db=L×2
这里
D
p
,
D
h
,
D
b
D_p, D_h, D_b
Dp,Dh,Db表示来自patch/head/block选择决策网络的平坦概率向量的大小,即分别为整个变压器的patch、head和block的总数。超参数
γ
p
,
γ
h
,
γ
b
∈
(
0
,
1
]
\gamma_p, \gamma_h, \gamma_b \in(0,1]
γp,γh,γb∈(0,1]表示以块/头/块保持的百分比表示的目标计算预算。
min
θ
,
W
L
=
L
c
e
+
L
usage
(
12
)
\min _{\boldsymbol{\theta}, \mathbf{W}} L=L_{c e}+L_{\text {usage }}(12)
θ,WminL=Lce+Lusage (12)
最后,结合两个损失函数,以端到端方式最小化,如Eqn. 12。