【学习笔记】Distilling the knowledge in a neural network

在这里插入图片描述

背景:making predictions using a whole ensemble of models is cumbersome and may be too computationally expensive to allow de- ployment to a large number of users, especially if the individual models are large neural nets.
贡献: significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model.

1. 摘要

背景

要提高机器学习算法的性能,一个方法就是在同一个数据上训练许多不同的模型并平均预测,但是使用这种a whole ensemble of models(集成模型)是笨重的,需要的计算成本很高,无法部署带大量的用户。对此,有研究表明,将集成模型中的知识压缩到一个易于部署的小型的单个模型是可行的,文章也基于此展开

贡献

  • 在MNIST数据集上有令人惊喜的结果,将集成模型的知识提炼到单个模型中可以显著改进大量应用于商业系统的语音模型
  • 介绍了一种新的集成(ensemble)类型,它由一个或多个完整模型和许多专家模型组成,这些专家模型学习区分完整模型所混淆的细粒度类。与混合专家不同,这些专家模型可以快速、并行地训练

2. 介绍

2.1 distillation的提出

   在模型的训练和部署阶段,两者采用的模型是相似的,但是他们的要求是不同的,在训练阶段,需要从大量的冗余的数据中获取结构,但是不需要实时操作并且可以可以使用大量的计算,在部署到大量用户的阶段,对于实时性和计算资源的要求较为严格。
   基于此,文中提出了一种distillation的概念:当训练阶段的复杂模型(teacher)被训练出来后,那么就使用蒸馏的办法将知识转移到适合部署的小模型(student)上面。 这个复杂得模型可以是独自训练模型的集成,也可以是一个用强大正则器如dropout训练的单个大模型

2.2 相关知识点

1)教师网络可能可以从相关性的概率告诉学生如何去泛化

  对于学习区分大量类别的复杂模型,通常的训练目标是最大化正确类别的平均对数概率,但学习的副作用是,训练好的模型会为所有错误类别分配概率,即使这些概率非常小,其中一些也比其他概率大得多。错误类别的概率告诉我们很多关于复杂模型如何泛化的信息。例如,一张宝马车的图片可能只有极小的机会被误认为是垃圾车,但这种错误仍然比将宝马车误认为胡萝卜的可能性高很多倍。

2)文中指出使用distillate可以很好的让小模型和大模型一样有相似的泛化能力

方法:将复杂模型产生的类概率作为软目标来训练小型模型,概率中有很多隐式信息
优点:软目标(soft targets):相比起hard targets,具有高熵值(信息量多),并且训练每一个样本时的梯度差异更小。
基本知识参考

  • Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。比如(1,0,0):熵的大小
  • Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。比如(0.3,0.4,0.6):包含更多信息,熵更大
  • 信息熵: 随机变量取值个数越多,状态数也就越多,累加次数就越多,信息熵就越大,混乱程度就越大,纯度越小 参考
    H ( x ) = − ∑ i = 1 n P ( x i ) l o g ( P ( x i ) ) H(x)=- \sum_{i=1}^n P(x_i )log(P(x_i )) H(x)=i=1nP(xi)log(P(xi))

3)文中举的例子

参考
在MNIST数据集中做手写体数字识别任务,可以看到当左边的2更像3,右边的2更像7,这在soft-taget中给负样本3和7的概率相比于其他负样本来说是高的,这也再次说明soft target比hard target有更多的信息
注:这些信息对迁移阶段的交叉熵代价函数的影响很小,因为这些概率值接近零
在这里插入图片描述

4)Logits(逻辑单元)

参考
Logits:最终全连接层的输出。比如在分类问题中,当汇总了网络内部的各种信息,会得出各个类别的汇总分值   z i \ z_i  zi,这个就是Logits,通常下一步使用softmax和sigmoid进行归一化。得到最终分类结果的概率

  • soft激活函数
    在这里插入图片描述

需要注意:softmax在类别归属满足概率分布的同时,会放大Logits数值之间的差异,会使得Logits得分两极分化,那么一部分负标签会被压扁为0,导致在网络训练的时候丢失这部分信息
如下图所示:在网络训练时,类概率层(如图 3 中 T=1 的软目标)的父标签输出的信息基本已经丢失. 将该类概率作为学生的监督信号,相当于让学生学习硬目标知识[1]。为了解决这个问题,文中提出了带有温度T的类概率,来控制输出概率的软化程度。
在这里插入图片描述

3. distillation蒸馏

3.1 蒸馏

1)引入带有参数T的类概率

在这里插入图片描述

  • T=1,表示网络输出 Softmax的类概率,当T<1时,概率分布比原始更 “陡峭”, 当T→0 时, Softmax 的输出值会接近于hard-target,T>1 时, 概率分布比原始更“平 缓"
  • T=   + ∞ \ + \infty  +,此时表示网络输出的逻辑单元,此时softmax的值是平均分布的
对于温度T的理解

在知识蒸馏中,使用高温将知识蒸馏出来,然后再恢复低温T=1进行学生模型的测试。那么这个温度T需要控制为多少才能将需要的知识蒸馏出来呢?
可以看出,随着温度T的升高,softmax的输出分布越来越平滑,信息熵也会越来越大,那么在student模型的训练过程中对于负标签的关注也会增加,特别是那些概率值显著高于平均概率值的负标签。

基于此可以具体情况具体分析:

  • 当student模型较小的时候,可以把温度调低,这样负标签的干扰就会减少
  • 当想从负标签中学到一些信息量的时候,可以把温度T调高

3) 蒸馏知识框架

  • student模型在训练软目标1时,要在高温环境(T较大)下进行,同时与它做交叉熵的,是teacher模型在相同的温度(T较大)输出的软目标
  • 在训练结束后,其使用温度 T = 1.0 ,输出软目标2,总的优化目标是软目标和硬目标的加权和

在这里插入图片描述

4)蒸馏总损失

参考
 cost function  = λ  CrossEntropy  ( y s , y t ) + ( 1 − λ ) Cross ⁡ Entropy ⁡ ( y s ( T = 1 ) , y ) \text { cost function }= \lambda \text { CrossEntropy }\left(y_{s}, y_{t}\right)+(1-\lambda) \operatorname{Cross} \operatorname{Entropy}\left(y_{s}(T=1), y\right)  cost function =λ CrossEntropy (ys,yt)+(1λ)CrossEntropy(ysT=1,y)

CrossEntropy为交叉熵函数
$\ y_{s}$ 表示student模型的预测结果
$\ y_{t}$ 表示teacher模型的预测结果
 y 是student模型的真实标签(硬标签向量)。
  • 第一个交叉熵函数是student模型预测结果与软目标的交叉熵,其使用与训练该软目标的teacher中一样的较高温度
  • 第二个目标函数 是student模型与真实标签的交叉熵,其温度设置为 1.0

好的结果是在第二个目标函数上使用较低的权值,由于软目标产生的梯度缩放了   1 / T 2 \ 1 / T^{2}  1/T2 ,因而在同时使用软目标和硬目标时,将软目标的梯度乘以   T 2 \ T^{2}  T2 是非常重要的。这确保了在实验过程中,如果用于蒸馏的温度发生改变,那么硬目标和软目标的相对贡献大致保持不变

3.2 蒸馏的一种特殊形式:直接matching logits

直接Matching Logits:直接使用softmax层的输入logits(而不再是输出)作为Soft target, 需要最小化的目标函数是Teacher模型和Student模型的logits之间的平方差(zi是学生模型的logits,vi是教师模型的logits)
  L logits  = 1 2 ( z i − v i ) 2 \ L_{\text {logits }}=\frac{1}{2}\left(z_{i}-v_{i}\right)^{2}  Llogits =21(zivi)2
  z i \ z_{i}  zi求梯度可得:

∂ L logits  ∂ z i = z i − v i \frac{\partial L_{\text {logits }}}{\partial z_{i}}=z_{i}-v_{i} ziLlogits =zivi

再看一般蒸馏中   L s o f t \ L_{soft}  Lsoft   z i \ z_{i}  zi求梯度可得:

∂ C ∂ z i = 1 T ( q i − p i ) = 1 T ( e z i / T ∑ j e z j / T − e v i / T ∑ j e v j / T ) \frac{\partial C}{\partial z_{i}}=\frac{1}{T}\left(q_{i}-p_{i}\right)=\frac{1}{T}\left(\frac{e^{z_{i} / T}}{\sum_{j} e^{z_{j} / T}}-\frac{e^{v_{i} / T}}{\sum_{j} e^{v_{j} / T}}\right) ziC=T1(qipi)=T1(jezj/Tezi/Tjevj/Tevi/T)

当 T → ∞ 时 , 有 z i T → 0 和 v i T → 0 , 根据泰勒公式的一阶展开 , 当 x → 0 时有 exp ⁡ ( x ) → x + 1 , 则有 : 当 T \rightarrow \infty 时, 有 \frac{z_{i}}{T} \rightarrow 0 和 \frac{v_{i}}{T} \rightarrow 0, 根据泰勒公式的一阶展开, 当 x \rightarrow 0 时有 \exp (x) \rightarrow x+1, 则有: T,Tzi0Tvi0,根据泰勒公式的一阶展开,x0时有exp(x)x+1,则有:

∂ C ∂ z i ≈ 1 T ( 1 + z i / T N + ∑ j z j / T − 1 + v i / T N + ∑ j v j / T ) \frac{\partial C}{\partial z_{i}} \approx \frac{1}{T}\left(\frac{1+z_{i} / T}{N+\sum_{j} z_{j} / T}-\frac{1+v_{i} / T}{N+\sum_{j} v_{j} / T}\right) ziCT1(N+jzj/T1+zi/TN+jvj/T1+vi/T)

此时, 假设 Logits 在每个样本上是零均值(zero-meaned)的, 则进一步近似:

∂ C ∂ z i ≈ 1 N T 2 ( z i − v i ) \frac{\partial C}{\partial z_{i}} \approx \frac{1}{N T^{2}}\left(z_{i}-v_{i}\right) ziCNT21(zivi)

可见, 与直接Matching Logits的方式相比, 当温度   T → ∞ \ T \rightarrow \infty  T 时,经过Softmax的蒸馏方式的Soft-target损失函数部分与其是等价的, 即 Matching Logits是一般知识蒸馏方法的一种特殊形式。

4. 实验

此处只对列出了实验一和实验二,和蒸馏的思想相关,实验三是关于集成模型的,具体可看原文

4.1 MINST(手写数字识别)

1)数据集
  • teacher网络:首先在60000个训练样本上训练了一个带有两个隐藏层 (每层有1200个单元) ,该网络使用dropout和权重约束weight-constraints的正则化方法
  • student网络1:两个隐藏层 (每层有800个单元) ,不带正则方法
  • student网络2:添加了T=20的大网络软目标任务
  • 输入的图片:在任意方向上抖动了两个像素
2)错误结果
teacherstudent1student2
67个146个74个

表明:软目标可以将大量知识转移到蒸馏模型中,包括从翻译训练数据中学到的关于如何泛化的知识,即使转移集不包含任何翻译。

3)将迁移数据集中数字 3 的样本移除

蒸馏模型(学生模型)只产生了206次错误,而且其中大部分偏差都是由于学习的3类偏差太低

4.2语音识别(speech recognition)

  • baseline:Android语音搜索的声学模型
  • 10xEnsemble:使用不同的参数训练了10个DNN(用于自动语音识别ASR的深度神经网络)来预测,使用与基线完全相同的架构和训练数据,对这10个模型的预测结果求平均作为emsemble的结果,相比于单个模型有一定的提升,但是不那么显著,于是可以使用更简单的蒸馏
  • Distillted Single model:将上述10个模型作为teacher网络,训练student网络。得到的Distilled Single model相比于直接的单个网络,也有一定的提升

在这里插入图片描述

5 结论

  • 蒸馏对于将知识从集成或从大型高度正则化模型,转移到较小的蒸馏模型非常有效。
  • 在MNIST上,即使用于训练蒸馏模型的转移集缺少一个或多个类的任何示例,蒸馏也非常有效。
  • 总的来说就是蒸馏能够压缩模型的大小还能保持模型的性能

参考文献

[1]黄震华,杨顺志,林威等.知识蒸馏研究综述[J].计算机学报,2022,45(03):624-653.
文章大部分参考的是这个:写得太好啦


  • 11
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值