这篇文章是Hinton的大作,采用了一个很特别的方式(知识蒸馏,KT, Knowledge distillation)来进行模型压缩。言而总之,就是预先训练一个大模型去调教小模型,使得小模型更够在应用端更好的跑起来。
文章地址:Distilling the Knowledge in a Neural Network, Hinton et al, 2015
为什么要采用知识蒸馏呢
作者认为对于分类模型来说,标签label提供的信息比较少,分类模型的目的就是将输入向量进行映射到类别中的一个某一点上。如果此时有一个训练好的大模型,也称之教师模型,可以采用知识蒸馏的算法模型将模型的输出(softmax输出),拿来作为小模型的另一个标签。此时,教师模型的输出可以作为soft target(区别于原始的label,hard target),此时相当于增加了分类模型的监督信息。从而使得小模型更好的泛化。
对于这一部分,作者也进行了一些解释。比如输入一张宝马车的图片,虽然不会分类成垃圾车(大概给这一类别的概率有1e-6),但是给胡萝卜的概率更低(1e-7)。因此这种soft target也能提供出包含分类模型的类间信息及类内信息。bmw车更像垃圾车,而不是胡萝卜。
这里附上知乎上各位大神的理解:
作者:Naiyan Wang
链接:https://www.zhihu.com/question/50519680/answer/136363665
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
这个问题真的很有意思,我也曾经花了很多时间思考这个问题,可以有一些思路可以和大家分享。一句话简言之,Knowledge
Distill是一种简单弥补分类问题监督信号不足的办法。传统的分类问题,模型的目标是将输入的特征映射到输出空间的一个点上,例如在著名的Imagenet比赛中,就是要将所有可能的输入图片映射到输出空间的1000个点上。这么做的话这1000个点中的每一个点是一个one
hot编码的类别信息。这样一个label能提供的监督信息只有log(class)这么多bit。然而在KD中,我们可以使用teacher
model对于每个样本输出一个连续的label分布,这样可以利用的监督信息就远比one
hot的多了。另外一个角度的理解,大家可以想象如果只有label这样的一个目标的话,那么这个模型的目标就是把训练样本中每一类的样本强制映射到同一个点上,这样其实对于训练很有帮助的类内variance和类间distance就损失掉了。然而使用teacher
model的输出可以恢复出这方面的信息。具体的举例就像是paper中讲的,
猫和狗的距离比猫和桌子要近,同时如果一个动物确实长得像猫又像狗,那么它是可以给两类都提供监督。综上所述,KD的核心思想在于"打散"原来压缩到了一个点的监督信息,让student模型的输出尽量match
teacher模型的输出分布。其实要达到这个目标其实不一定使用teacher
model,在数据标注或者采集的时候本身保留的不确定信息也可以帮助模型的训练。当然KD本身还有很多局限,比如当类别少的时候效果就不太显著,对于非分类问题也不适用。我们目前有一些尝试试图突破这些局限性,提出一种通用的Knowledge
Transfer的方案。希望一切顺利的话,可以早日成文和大家交流。 ?
知识蒸馏的方式
我们采用一个温度T对深度网络的输出进行蒸馏,重新改变原始的softmax输出,当给定logit时,我们重新计算概率q为:
q
i
=
e
x
p
(
z
i
/
T
)
∑
j
e
x
p
(
z
j
/
T
)
q_i = \frac{exp(z_i / T)}{\sum_{j} exp(z_j / T)}
qi=∑jexp(zj/T)exp(zi/T)
一般来说, 温度T设置为1,当T更高的时候,教师网络产生的soft target将会更软。
给定一个输入,其Hard Target为【0,0,1】, Soft Target为【0.001, 0.149, 0.85】
进行soft后,其结果可能会变成【0.1,0.3,0.6】。从而提供更多的类间信息
在训练过程中,一般来说会结合两种客观损失函数。一种是原来的hard target,即one-hot的label信息,cross entropy with correct labels。另一种是soft target,即对教师网络的输出采用相同的温度进行cross entropy。
匹配logits可作为知识蒸馏的一种特殊方式
给定
C
C
C 为soft target带来的loss(对教师和学生模型输出的logits进行cross entropy), Teacher网络模型和Student的网络模型的logits分别为
v
i
,
z
i
v_i, z_i
vi,zi,输出的softmax概率分别为
p
i
,
q
i
p_i, q_i
pi,qi。
C
=
−
∑
j
=
1
K
p
j
l
o
g
q
j
C = - \sum_{j=1}^{K} p_j log q_j
C=−j=1∑Kpjlogqj
其中:
p
i
=
e
x
p
(
v
i
/
T
)
∑
j
e
x
p
(
v
j
/
T
)
p_i = \frac{exp(v_i / T)}{\sum_j exp(v_j / T)}
pi=∑jexp(vj/T)exp(vi/T)
q
i
=
e
x
p
(
z
i
/
T
)
∑
j
e
x
p
(
z
j
/
T
)
q_i = \frac{exp(z_i / T)}{\sum_j exp(z_j / T)}
qi=∑jexp(zj/T)exp(zi/T)
对于
∂
C
∂
z
i
\frac{\partial C}{\partial z_i}
∂zi∂C,我们有
∂
C
∂
z
i
=
−
1
T
∑
j
=
1
K
p
j
1
q
j
∂
q
j
∂
z
i
\frac{\partial C}{\partial z_i} = - \frac{1}{T} \sum^K_{j=1}p_j \frac{1}{q_j} \frac{\partial q_j}{\partial z_i}
∂zi∂C=−T1j=1∑Kpjqj1∂zi∂qj
此时注意,C是求解教师网络模型的概率输出与学生网络模型的概率输出交叉熵。注意softmax函数的特性,不管对任何特定类别的概率输出,都会包含所有的logits。
因此在计算
∂
q
j
∂
z
i
\frac{\partial q_j}{\partial z_i}
∂zi∂qj的时候要考虑
i
i
i与
j
j
j的关系、
- 当
i
i
i!=
j
j
j的时候:
∂ q j ∂ z i = − e z i e z j ( ∑ k e z k ) 2 = − q i q j \frac{\partial q_j}{\partial z_i} = \frac{-e^{z_i} e^{z_j}}{(\sum_k e^{z_k})^2} = - q_i q_j ∂zi∂qj=(∑kezk)2−eziezj=−qiqj - 当
i
=
=
j
i==j
i==j时:
∂ q j ∂ z i = q i ( 1 − q i ) \frac{\partial q_j}{\partial z_i} = q_i (1-q_i) ∂zi∂qj=qi(1−qi)
这样有:
∂
C
∂
z
i
=
−
1
T
∑
j
=
1
K
p
j
1
q
j
∂
q
j
∂
z
i
=
−
p
i
+
p
i
q
i
+
∑
j
=
1
,
j
!
=
i
k
p
j
q
i
=
q
i
−
p
i
\frac{\partial C}{\partial z_i} = - \frac{1}{T} \sum^K_{j=1}p_j \frac{1}{q_j} \frac{\partial q_j}{\partial z_i} = -p_i + p_i q_i + \sum^k_{j=1,j!=i}p_j q_i = q_i - p_i
∂zi∂C=−T1j=1∑Kpjqj1∂zi∂qj=−pi+piqi+j=1,j!=i∑kpjqi=qi−pi
按照论文中继续推导:
∂
C
∂
z
i
=
1
T
(
e
x
p
(
z
i
/
T
)
∑
j
e
x
p
(
z
j
/
T
)
−
e
x
p
(
v
i
/
T
)
∑
j
e
x
p
(
v
j
/
T
)
)
\frac{\partial C}{\partial z_i} = \frac{1}{T}(\frac{exp(z_i / T)}{\sum_j exp(z_j / T)} - \frac{exp(v_i / T)}{\sum_j exp(v_j / T)} )
∂zi∂C=T1(∑jexp(zj/T)exp(zi/T)−∑jexp(vj/T)exp(vi/T))
当蒸馏温度T够大时,上式可表达为:
∂
C
∂
z
i
≈
1
T
(
1
+
z
i
/
T
N
+
∑
j
z
j
/
T
−
1
+
v
i
/
T
N
+
∑
j
e
x
p
(
v
j
/
T
)
)
\frac{\partial C}{\partial z_i} \approx \frac{1}{T}(\frac{1+z_i/T}{N+\sum_j z_j / T} - \frac{1+v_i/T}{N+\sum_j exp(v_j / T)} )
∂zi∂C≈T1(N+∑jzj/T1+zi/T−N+∑jexp(vj/T)1+vi/T)
若假设logits为零均值,即
∑
j
z
j
=
∑
j
v
j
=
0
\sum_j z_j = \sum_j v_j = 0
∑jzj=∑jvj=0,
则上式可以简化为:
∂
C
∂
z
i
≈
1
N
T
2
(
z
i
−
v
i
)
\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2} (z_i - v_i)
∂zi∂C≈NT21(zi−vi)
此时可以注意到,如果增加一个约束条件,蒸馏温度T->足够大,则蒸馏形式类似于去求解教师网络和学生网络的logits之间的均方差,即目标函数为 1 / 2 ( z i − v i ) 2 1/2(z_i - v_i)^2 1/2(zi−vi)2