1. 概述
这篇文章是比较早的文章,但是很多后序的工作都是源自于此,文章中提出了使用soft target(文章中说的是softmax层的输出)来实现大的模型到小的模型之间的知识迁移,从而提升小模型的性能。对于这里使用soft target而非hard target(如分类的类别目标),其原因是软目标能够提供更多可供训练的信息,而硬目标则会造成梯度上方差的减小。有了软目标的帮助小的模型能够更少的参数与更高的学习率。
对于大模型中软目标的值可能相差很大的情况,如果直接使用大模型中的软目标进行训练将会使得小模型过多关注软目标中的较大值,而忽略了其它的相似性信息,因而文章引入了temperature的概念使得软目标的分布区间变小,从而能够学习到更多的信息。
2. 温度参数
上文说到使用temperature参数加在softmax层上对输出的概率分布进行软化,其对应的数学表达为:
q
i
=
e
x
p
(
z
i
/
T
)
∑
j
e
x
p
(
z
j
/
T
)
q_i=\frac{exp(z_i/T)}{\sum_{j}exp(z_j/T)}
qi=∑jexp(zj/T)exp(zi/T)
在小的模型训练的时候是加大参数
T
T
T的,而在训练完成之后是将T设置回正常值
T
=
1
T=1
T=1的。
在训练的过程中小目标的目标函数是由两部分组成的:
- 1)大模型用较大 T T T构造软目标与小模型对应 T T T参数输出的预测结果的交叉熵损失,这部分的损失由于参数 T T T的引入导致梯度下降 1 T 2 \frac{1}{T^2} T21,因而需要在计算梯度的时候进行补偿,乘以系数 T 2 T^2 T2;
- 2)小目标在参数 T = 1 T=1 T=1的时候计算输出与真实硬目标的交叉熵损失,对于这里的提到的两个损失是使用一个参数 α \alpha α来进行调整的,一般来讲带温度参数 T T T的损失是占主导的;
对于带温度参数
T
T
T的损失,其梯度计算为:
∂
C
∂
z
i
=
1
T
(
q
i
−
p
i
)
=
1
T
(
e
z
j
/
T
∑
j
e
z
j
/
T
−
e
v
j
/
T
∑
j
e
v
j
/
T
)
\frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{e^{z_j/T}}{\sum_je^{z_j/T}}-\frac{e^{v_j/T}}{\sum_je^{v_j/T}})
∂zi∂C=T1(qi−pi)=T1(∑jezj/Tezj/T−∑jevj/Tevj/T)
如果上式中的温度参数比较高那么,那么上面的梯度计算可以近似描述为:
∂
C
∂
z
i
≈
1
T
(
1
+
z
j
/
T
N
+
z
j
/
T
−
1
+
v
j
/
T
N
+
v
j
/
T
)
\frac{\partial C}{\partial z_i}\approx\frac{1}{T}(\frac{1+z_j/T}{N+z_j/T}-\frac{1+v_j/T}{N+v_j/T})
∂zi∂C≈T1(N+zj/T1+zj/T−N+vj/T1+vj/T)
再进一步假设logits是经过0均值的那么
∑
j
z
j
=
∑
j
v
j
=
0
\sum_jz_j=\sum_jv_j=0
∑jzj=∑jvj=0,那么上面的梯度计算就可以描述为:
∂
C
∂
z
i
≈
1
N
T
2
(
z
i
−
v
i
)
\frac{\partial C}{\partial z_i}\approx\frac{1}{NT^2}(z_i-v_i)
∂zi∂C≈NT21(zi−vi)
3. 总结
这篇文章的原理比较简单,也很容易理解,后序还有很多人基于此在各个领域运用知识蒸馏方法获得小模型。这里推荐一份代码帮助熟悉知识蒸馏的运用:knowledge-distillation-pytorch