Distilling the knowledge in a Neural Network
这篇文章出自概念的提出者,深度学习界的顶级大牛Hinton之手,以下内容为本人阅读此文时的感受,文献地址:知识蒸馏
读文章的第一步当然是从摘要出发,看本文到底是干什么的,有怎样的作用,把握了中心思路再去阅读才能起到事半功倍的效果,否则引言和相关工作就能让你的头变大,开始!
摘要
在相同的数据集上训练不同的模型,然后取这些模型的预测平均值是能够提高几乎所有的机器学习方法性能的!因为计算代价太大,并且若是单个的模型为深度神经网络,模型几乎不能够部署到用户端。
在2006年 ACM的一篇文章中已经证明,知识是可以进行蒸馏的(模型压缩)。而本文应用了一种不同的压缩技术来进行知识的蒸馏,通过将集成模型的压缩实现知识的蒸馏,在手写数字数据集MNIST上取得了非常好的效果,并显著提升了一款用于商业的语音模型。本文还进一步介绍了一种新型的模型集成方法,它能够让模型学习区分细粒度的类,而这些类在之前的模型中是分不太清楚的。并且不同于各个模型的简单混合,这些模型可以快速并行地训练。
小结:
1、知识蒸馏的定义
2、本文提出的一种新的模型压缩技术
3、新的模型集成方法
引言
许多昆虫都有幼虫(larval)形态,可以从环境中提取能量和营养,还有完全不同的成虫形态,适合旅行和不同的繁殖需求。就像我们的模型一样,在训练的时候如同幼虫,吸取数据的特征,而在商业化部署的时候需要一个精简的成熟状态。在大规模机器学习中,我们通常使用训练阶段和部署阶段非常相似的模型,尽管其要求非常不同:对于语音和对象识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但不需要实时操作,它可以使用大量的计算。然而,部署到大量用户,对延迟和计算资源有更严格的要求。与昆虫的类比表明,如果我们更容易从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可以是一个单独训练的模型的集合,也可以是一个使用非常强的正则化器训练的非常大的模型,如dropout。一旦繁琐的模型被训练好,就可以使用一种不同的我们称之为“蒸馏”的训练过程,从繁琐的模型中转移出知识,使得它更加易于部署。里奇·卡鲁阿纳和他的合作者已经开创了这个策略的一个版本。在他们的论文中,已经证明把大型模型集合所获得的知识可以转移到单一的小模型中。
通过深度学习训练的模型所获取到的知识是一大堆的参数,这就使得解释我们是如何做到在保留已有的知识的同时改变模型的形式 变得很难。
对于知识的一个更加抽象即放之四海而皆准的定义:学习输入到输出的一个匹配关系。
大模型需要学习可区分大量的类别的特征,目标函数一般设定为
a
r
g
m
a
x
(
l
o
g
P
(
p
r
e
d
i
r
c
t
=
g
r
o
u
n
d
t
r
u
t
h
)
)
argmax (log P(predirct = groundtruth) )
argmax(logP(predirct=groundtruth))
然而,这种普通的方法会导致一种副产物:存在着很多不正确的答案的概率依然存在,并且这些情况中的某些概率值之间相差甚远。
文中举了个例子:宝马车被错分成垃圾车的情况,这种情况出现的概率本身很小,但是与将宝马车分类为胡萝卜的情况相比,却要大得多。大体表明模型在训练的过程中只关注与ground truth(真实标签)的对比,而并不关注与其他的错误类的比较情况。
模型的训练是为了取得更好的表现效果,不仅是在训练集上,更重要的,也证明这个模型是有效的,是在测试集上的效果,更好的去解决未知的问题,即在真实的场景下有好的泛化性能。
有目共睹的事实
在蒸馏的过程中,还是发现集成学习的方法所拥有的泛化性能更佳,大小模型都是如此,因为取得是多个模型分类结果的平均值,肯定是要比单个模型傻乎乎的训练要好一些的。
本文重点:
将集成模型的方法转移到小模型上,传功大法:
方 法一: 使用笨重的大模型产生的类别概率的”“软标签“‘来训练小模型。
对于这一步转移,可以使用相同的训练集或者是绝对的大模型转换过的集合。
当复杂模型是一大堆简单模型的组合时,使用各个分类器的预测分布的算术平均或者几何平均作为软标签,如果这种软标签的熵很高时,有以下几个优点:
1、 每次训练所携带的信息更多
2、训练的样例中的梯度没有相差太大(应该是)
3、小模型的训练需要的训练集可以更小
4、小模型的学习率可以更大
对于像MNIST这样的任务,大模型总是非常自信产生正确的答案。所以学习函数的信息来源都是在那些概率小的标签里面。
举个栗子:对于数字2,可能出现这种情况,这个数据判断为3的可能性为1e-6, 而为7的可能性为1e-9,就这个例子而言,就是它是有可能被认为是3或者7的,因为2 3 7之间有相似的结构,但是由于概率太小,所以在使用交叉熵损失函数的时候基本没有影响。
而在里奇·卡鲁阿纳的文章里提出使用logits函数而非softmax来学习小模型,并且优化函数的目标变为,由大模型和小模型产生的logits函数映射结果的平方差函数。
方法二:
本文提出一种新的解决方案,称之为蒸馏,distillation:在大模型产生一个可是的软标签集合前,提升关于最后的softmax概率分布的温度 T ,也是本文的中心调控变量!!然后使用带有高温的软标签集合来训练小模型。
后面也证明方法一是蒸馏的一种形式。
用来训练小模型的数据集可以时无标签的也可以是原始的训练集,本文指出,使用原始的训练集效果更好。特别是当我们在目标函数中添加小的机制,使得模型匹配真实值和软标签的性能都好。
即 软标签和真实标签都能够为小模型的训练作辅助!
何为蒸馏
公式一:
q
i
=
e
x
p
(
z
i
/
T
)
Σ
j
e
x
p
(
z
i
/
T
)
q_i = \frac{exp(z_i/T)}{Σ_jexp(z_i/T)}
qi=Σjexp(zi/T)exp(zi/T)
公式中的T一般来说取1,但是T越高,标签越软。(T代表温度,化学中蒸馏所需的温度是高于沸点的)这里理解为 T > 1.eq
对于公式的比较,设T = 3, 存在三个类,此次取得值为 (0.125 、 0.25 、 0.625),代入公式有
T
=
1
时
,
q
=
(
e
0.125
e
0.125
+
e
0.25
+
e
0.625
,
e
0.25
e
0.125
+
e
0.25
+
e
0.625
,
e
0.625
e
0.125
+
e
0.25
+
e
0.625
)
,
T = 1时, q = (\frac{e^{0.125}}{e^{0.125}+e^{0.25}+e^{0.625}}, \frac{e^{0.25}}{e^{0.125}+e^{0.25}+e^{0.625}},\frac{e^{0.625}}{e^{0.125}+e^{0.25}+e^{0.625}}),
T=1时,q=(e0.125+e0.25+e0.625e0.125,e0.125+e0.25+e0.625e0.25,e0.125+e0.25+e0.625e0.625),
T
=
2
时
,
q
=
(
e
0.0625
e
0.0625
+
e
0.125
+
e
0.3125
,
e
0.125
e
0.0625
+
e
0.125
+
e
0.3125
,
e
0.03125
e
0.0625
+
e
0.125
+
e
0.3125
)
T = 2时, q = (\frac{e^{0.0625}}{e^{0.0625}+e^{0.125}+e^{0.3125}}, \frac{e^{0.125}}{e^{0.0625}+e^{0.125}+e^{0.3125}},\frac{e^{0.03125}}{e^{0.0625}+e^{0.125}+e^{0.3125}})
T=2时,q=(e0.0625+e0.125+e0.3125e0.0625,e0.0625+e0.125+e0.3125e0.125,e0.0625+e0.125+e0.3125e0.03125)
由于式子的值的大小会随着T的增大而增大,后续接着证明。
下面介绍了两个重要的公式,也是本文的核心公式。
公式一:
∂
C
∂
z
i
=
1
T
(
q
i
−
p
i
)
=
1
T
(
e
z
i
/
T
∑
j
e
z
j
/
T
−
e
v
i
/
T
∑
j
e
v
j
/
T
)
\frac{\partial C}{\partial z_i} = \frac{1}{T}(q_i - p_i)=\frac{1}{T}(\frac{e^{z_i/T}}{\sum_je^{z_j/T}}-\frac{e^{v_i/T}}{\sum_je^{v_j/T}})
∂zi∂C=T1(qi−pi)=T1(∑jezj/Tezi/T−∑jevj/Tevi/T)
公式二:
i
f
T
>
>
1
,
则
有
if T >> 1,则有
ifT>>1,则有
∂
C
∂
z
i
≈
1
T
(
1
+
z
i
/
T
N
+
∑
j
z
j
/
T
−
1
+
v
i
/
T
N
+
∑
j
v
j
/
T
)
\frac{\partial C}{\partial z_i} ≈ \frac{1}{T}(\frac{1+z_i/T}{N+\sum_jz_j/T}- \frac{1+v_i/T}{N+\sum_jv_j/T})
∂zi∂C≈T1(N+∑jzj/T1+zi/T−N+∑jvj/T1+vi/T)
∵
∑
j
z
j
=
∑
j
v
j
=
0
,
∴
上
式
等
价
于
∵\sum_j z_j = \sum_j v_j = 0,∴上式等价于
∵∑jzj=∑jvj=0,∴上式等价于
∂
C
∂
z
i
≈
1
N
T
2
(
z
i
−
v
i
)
\frac{\partial C}{\partial z_i} ≈\frac{1}{NT^2}(z_i-v_i)
∂zi∂C≈NT21(zi−vi)
所以在高温的控制下,蒸馏就变成了
a
r
g
m
i
n
(
1
2
(
z
i
−
v
i
)
2
)
argmin(\frac{1}{2}(z_i-v_i)^2)
argmin(21(zi−vi)2)这样的形式,也是本文最终要优化的目标函数。
实验
首先本文在MNIST数据集上做了理论验证,并且是十分详尽的,且看:
第一步,60000张图片,在含有隐藏单元两层,每层1200个的网络中进行训练,其中使用了dropout和权重限制技术,网络最后实现了只有67张错误的测试效果。
第二步,设计一个800单元每层的还是两层隐藏层的网络,并且在没有使用dropout和其他正则化器的情况下实现了有146张的测试错误结果,并且在使用T=20加上正则化,在软标签的训练中,错误的张数降低到了74。这表明蒸馏是有效的。
第三步,进一步降低数量量至每层300,发现T=8时也能得到以上述相似的结果。
第四步,降低到每层只有30个隐藏单元,这时情况发生了变化,设T=2.5~4的情况下,蒸馏取得的效果会比T在其他的情况要好!
最后,介绍一下神奇的对照试验。
收取所有的3字图片,让蒸馏小模型看不见3,所得到的结果令人惊奇,1010张图片,206张错误,其中133张是3,主要看不见,然后稍微改变一下,bias=3.5,效果好了一些,109张错误分类,14张3,如果将7和8留下,则模型的错误率会上升到47.3%,改一下偏置,下降到13.2%.
总之就是这么干,可以!
后续的两个实验进一步证明了上述的理论,由大模型们组成的团队产生相应的软标签,蒸馏小模型在特定的温度下进行训练,结果很棒,即知识的转移时成功的,因为相比较原始的模型,在参数量下降的前提下依然做到了测试集上效果的保持。
下次一定记得及时保存,写了一大堆,突然一个意外关了浏览器,就没了!!!心态不太好!