本篇论文来自2021年NeurIPS,论文地址点这里
一. 介绍
知识蒸馏是一种著名的技术,用于学习具有竞争性精度的紧凑深度神经网络模型,其中较小的网络(学生)被训练来模拟较大网络(教师)的表示。知识蒸馏的普及主要是由于它的简单性和通用性;根据教师学习学生模型是很简单的,并且对两个模型的网络架构没有限制。大多数方法的主要目标是在预先定义和预先训练的教师网络下,如何有效地将暗知识转移到学生模型中。虽然知识蒸馏是一种很有前途和方便的方法,但它在准确性方面有时不能达到令人满意的性能。这在一定程度上是由于学生的模型能力相对于教师的模型能力过于有限,知识蒸馏算法是次优的。除此之外,我们认为师生特征的一致性对知识转移至关重要,教师不恰当的表征学习往往会导致知识蒸馏的次优性。
在本文中,提出可一种学生友好型的教师网络模型(SFTN),如下图(b)使得教师网络在训练过程中能够获得学生的相关网络结构,从而达到将知识很好进行转移的目的。
二. 方法
2.1 综述
传统的知识蒸馏方法试图在教师网络结构的基础上寻找学生网络教学的途径。教师网络的训练是基于对基本事实的损失,但目标不一定有利于对学生的知识蒸馏。与此相反,SFTN框架旨在提高从教师模型到学生模型的知识蒸馏的有效性。
模块化教师网络以及学生网络: 我们根据层的深度和特征图的大小将师生网络模块化为多个块。这是因为通常每隔3或4个块进行知识蒸馏,以准确提取和转移教师模型中的知识。上图展示了每个网络可以分为3个块,我们使用 { B T 1 , B T 2 , B T 3 } \{B^1_T,B^2_T,B^3_T\} {BT1,BT2,BT3}以及 { B S 1 , B S 2 , B S 3 } \{B^1_S,B^2_S,B^3_S\} {BS1,BS2,BS3}表示为学生网络和教师网络。
增加学生分支: SFTN将学生分支扩展为教师模式,用于两部分的联合训练。每个学生分支由一个教师网络特征变换层 T \mathcal{T} T和一个学生网络块组成。这个 T \mathcal{T} T主要是将 F T i \mathbf{F}_{\mathrm{T}}^i FTi的输出通道维度转换为 B S i + 1 B_{\mathrm{S}}^{i+1} BSi+1的输入通道维度。根据教师和学生网络的配置,转换需要增加或减少特征映射的大小。我们这里使用 3 × 3 3 \times 3 3×3的卷积和大小去减少 F T i \mathbf{F}_{\mathrm{T}}^i FTi的大小,或使用 4 × 4 4 \times 4 4×4的卷积核去增加 F T i \mathbf{F}_{\mathrm{T}}^i FTi的大小,也可以使用 1 × 1 1 \times 1 1×1的大小来保证不变性。转换到学生分支的特征被单独转发,以计算分支的logit。例如,如上图(a)所示,教师流中的 F T 1 \mathbf{F}_{\mathrm{T}}^1 FT1被转换为适合 B S 2 B_{\mathrm{S}}^{2} BS2, B S 2 B_{\mathrm{S}}^{2} BS2启动一个学生分支来派生 q R 1 \mathbf{q}_{\mathrm{R}}^1 qR1,而另一个学生分支则从 F T 2 \mathbf{F}_{\mathrm{T}}^2 FT2的转换特征开始。注意, F T 3 \mathbf{F}_{\mathrm{T}}^3 FT3在图中没有尾随的教师网络块,也没有相关的学生分支,因为它被直接用于计算主教师网络的逻辑输出。
SFTN的训练: 教师网络与教师中的单个块对应的多个学生分支一起训练,在这里我们将教师和学生分支之间的表示差异最小化。我们的损失函数由三项组成: 教师网络的损失 L T \mathcal{L}_T LT,学生分支中的KL损失 L R K L \mathcal{L}^{KL}_R LRKL,以及交叉熵损失 L R C E \mathcal{L}^{CE}_R LRCE。其中 L T \mathcal{L}_T LT最小化教师网络的输出 q T \mathbf{q}_{\mathrm{T}} qT和真实标签的差距, L R K L \mathcal{L}^{KL}_R LRKL则是确保教师输出和分支输出尽量相似,而最后 L R C E \mathcal{L}^{CE}_R LRCE则是保证分支也接近真实标签。
3.2 网络架构
SFTN由一个教师网络和多个学生分支机构组成。教师网络和学生网络被分为
N
N
N个块,其中教师网络中的一组块用
B
T
=
{
B
T
i
}
i
=
1
N
\mathbb{B}_{\mathrm{T}}=\left\{B_{\mathrm{T}}^i\right\}_{i=1}^N
BT={BTi}i=1N表示,学生网络中的一组块用
B
S
=
{
B
S
i
}
i
=
1
N
\mathbb{B}_{\mathrm{S}}=\left\{B_{\mathrm{S}}^i\right\}_{i=1}^N
BS={BSi}i=1N表示。注意,教师网络中的最后一个块没有相关的学生分支。
给定输入
x
\mathbf{x}
x,那么教师网络的输出
q
T
\mathbf{q}_{\mathrm{T}}
qT可以表示为:
q
T
(
x
;
τ
)
=
softmax
(
F
T
(
x
)
τ
)
(1)
\mathbf{q}_{\mathrm{T}}(\mathbf{x} ; \tau)=\operatorname{softmax}\left(\frac{\mathcal{F}_{\mathrm{T}}(\mathbf{x})}{\tau}\right) \tag1
qT(x;τ)=softmax(τFT(x))(1)
其中
F
T
(
x
)
\mathcal{F}_{\mathrm{T}}(\mathbf{x})
FT(x)表示为教师网络的逻辑输出,
τ
\tau
τ为温度超参数。另一方面,对于第
i
i
i个学生分支
q
R
i
\mathbf{q}^i_R
qRi,计算如下:
q
R
i
(
F
T
i
;
τ
)
=
softmax
(
F
S
i
(
T
i
(
F
T
i
)
)
τ
)
(2)
\mathbf{q}_{\mathrm{R}}^i\left(\mathbf{F}_{\mathrm{T}}^i ; \tau\right)=\operatorname{softmax}\left(\frac{\mathcal{F}_{\mathrm{S}}^i\left(\mathcal{T}^i\left(\mathbf{F}_{\mathrm{T}}^i\right)\right)}{\tau}\right) \tag2
qRi(FTi;τ)=softmax(τFSi(Ti(FTi)))(2)
2.3 损失函数
传统知识蒸馏框架中的教师网络仅用
L
T
\mathcal{L}_T
LT进行培训。然而,SFTN有额外的损失项,如2.1节所述的
L
R
K
L
\mathcal{L}^{KL}_R
LRKL和
L
R
C
E
\mathcal{L}^{CE}_R
LRCE。用
L
S
F
T
N
\mathcal{L}_{SFTN}
LSFTN表示为总损失,结果为:
L
S
F
T
N
=
λ
T
L
T
+
λ
R
K
L
L
R
K
L
+
λ
R
C
E
L
R
C
E
(3)
\mathcal{L}_{\mathrm{SFTN}}=\lambda_{\mathrm{T}} \mathcal{L}_{\mathrm{T}}+\lambda_{\mathrm{R}}^{\mathrm{KL}} \mathcal{L}_{\mathrm{R}}^{\mathrm{KL}}+\lambda_{\mathrm{R}}^{\mathrm{CE}} \mathcal{L}_{\mathrm{R}}^{\mathrm{CE}} \tag3
LSFTN=λTLT+λRKLLRKL+λRCELRCE(3)
其中
λ
\lambda
λ表示为各项的超参数。
接下来我们定义各项的损失。首先,
L
T
\mathcal{L}_T
LT使用交叉熵计算教师输出与真实标签的损失 :
L
T
=
CrossEntropy
(
q
T
,
y
)
(4)
\mathcal{L}_{\mathrm{T}}=\operatorname{CrossEntropy}\left(\mathbf{q}_{\mathrm{T}}, \mathbf{y}\right) \tag4
LT=CrossEntropy(qT,y)(4)
L
R
K
L
\mathcal{L}^{KL}_R
LRKL可以计算
N
−
1
N-1
N−1个分支的输出和教师输出的差距:
L
R
K
L
=
1
N
−
1
∑
i
=
1
N
−
1
K
L
(
q
~
R
i
∥
q
~
T
)
,
(5)
\mathcal{L}_{\mathrm{R}}^{\mathrm{KL}}=\frac{1}{N-1} \sum_{i=1}^{N-1} \mathrm{KL}\left(\tilde{\mathbf{q}}_{\mathrm{R}}^i \| \tilde{\mathbf{q}}_{\mathrm{T}}\right), \tag5
LRKL=N−11i=1∑N−1KL(q~Ri∥q~T),(5)
其中
q
~
R
i
\tilde{\mathbf{q}}_{\mathrm{R}}^i
q~Ri以及
q
~
T
\tilde{\mathbf{q}}_{\mathrm{T}}
q~T表示为使用稍大温度
τ
~
\tilde{\tau}
τ~进行软化后的结果。最后一个
L
R
C
E
\mathcal{L}^{CE}_R
LRCE则是计算各个分支和真实标签的损失:
L
R
C
E
=
1
N
−
1
∑
i
=
1
N
−
1
CrossEntropy
(
q
R
i
,
y
)
\mathcal{L}_{\mathrm{R}}^{\mathrm{CE}}=\frac{1}{N-1} \sum_{i=1}^{N-1} \operatorname{CrossEntropy}\left(\mathbf{q}_{\mathrm{R}}^i, \mathbf{y}\right)
LRCE=N−11i=1∑N−1CrossEntropy(qRi,y)
值得注意的是,对于
L
R
C
E
\mathcal{L}^{CE}_R
LRCE以及
L
T
\mathcal{L}_T
LT的温度都为1。
三. 总结
本次关于知识蒸馏的文章是在教师预训练阶段进行考虑的,通过提前引入学生的特征块,来方便教师网络更好地学习到有利于学生网路地暗知识。本次地代码没有找到,等之后我找到了我再更新上来。