知识蒸馏(Knowledge Distillation,KD)论文阅读笔记

一、引言

将一个预训练的大模型进行压缩和轻量化,使得模型能够部署在各种算力资源较少的嵌入式设备上,是如今的一个研究热点之一。知识蒸馏(Knowledge Distillation)是模型压缩的一个重要方法,其中KD方法可以说是知识蒸馏的起源,值得深入研究。

论文:《Distilling the Knowledge in a Neural Network》(NIPS 2014)

二、知识蒸馏方案的提出

1. knowledge transfer想法的提出

一个模型从诞生到投入实际应用,可以分为训练和部署两个主要阶段。在训练阶段,我们无需考虑训练成本、模型尺寸和实时性等要求,可以构建一个大型网络或者将许多简单模型集成到一起,投入大量算力、海量数据进行训练来达到我们的目标。但是,在部署阶段,我们需要考虑到硬件部署环境的算力资源、算法的实时性等,笨重的大模型往往难以满足部署要求。

因此,作者提出了一个方案:能否像老师教学生一样,把笨重的大模型学习到的知识转移给精简的小模型上,来满足我们的部署要求?

这就涉及到一个概念问题:对于要转移的知识应该如何定义?显然,把大模型学习到的权重等参数作为知识是行不通的。

2. knowledge的定义与模型的泛化能力

第一个想法是,以一个多分类模型为例,我们可以把大模型从softmax层输出的对各类别的预测概率/置信度作为soft targets作为要学习的知识提供给小模型训练。这里,soft targets这一概念相对于 非0即1的hard labels 给出。那么为什么不选择hard labels而选择soft targets呢?

类别soft targetshard labels
10.050
20.150
30.81

这是因为,我们训练一个模型的真正目的,不是为了让模型在一个给定的数据集上表现得有多好,而是为了让模型具备良好的泛化能力,即在未知数据集上也有很好的表现。

以手写数字分类器为例,在上面的表格中,我假设了一个输入图像的两种输出结果:一种是softmax层输出的置信度结果,另一种是离散的hard labels结果。两种结果表示的是一个意思,但是显然soft targets结果要包含更多的信息:输入的手写数字像3的概率是0.8,像2的概率是0.15,像1的概率是0.05。即:soft targets不仅包含了对正确答案的预测信息,还包含了对其他不正确答案的预测信息,这些信息都隐含了大模型实际学到的**“知识”**,对模型的泛化是非常重要的。此外,在不同的训练样本中,soft targets方案的梯度分散度更低,更soft,因此小模型可以比大模型使用更少的数据、更高的学习率来训练。

3. 蒸馏温度T

但是,使用softmax层直接输出的置信度结果即soft targets也存在缺点。论文在Chapter 1的末尾给出了一个同样是手写数字识别的例子来证明,笨重大模型往往对于正确答案给出很高的置信度,而对于其他不正确答案的置信度几乎趋近于0,导致在knowledge transfer阶段这些有价值的信息对交叉熵损失函数的影响很小。

于是,在原先思想的基础上,作者正式提出了知识蒸馏的方案:提升最终输出的softmax函数值的温度来使得大模型生成的targets变得softer以满足我们的要求;同时,对于小模型的训练我们使用同样的温度,这一过程就叫做“蒸馏”。这样,通过反向传播和梯度下降使蒸馏过程的损失函数最小化,我们使小模型的预测结果更加靠近大模型输出的soft targets,以达到知识转移的效果。

蒸馏时使用的训练集,也叫作转移数据集(transfer set),可以使用无标签的数据集,也可以使用训练大模型时使用的原始数据集。

三、知识蒸馏的具体方案

1. softmax函数计算公式(温度T)

首先,根据上一节对蒸馏温度T的定义,假设一个多分类网络输出的类别数为n,第i类对应的logits值为 z i z_{i} zi,给出温度T时softmax层第i类的预测概率值计算公式:
q i = e x p ( z i / T ) ∑ j = 1 n e x p ( z j / T ) q_{i}=\frac{exp(z_{i}/T)}{\sum_{j=1}^{n}exp(z_{j}/T)} qi=j=1nexp(zj/T)exp(zi/T)
观察发现,与往常的softmax函数相比,温度T下的softmax函数只是在指数位置加了个分母T,却能够起到使输出结果的概率分布变得softer的效果。

2. 知识蒸馏方案与流程

现在我们有了一个训练好的大模型作为教师,有了蒸馏公式,也给出了作为学生的小网络。如何来进行知识蒸馏呢?

首先给出知识蒸馏的流程框架如下,这个框架描述的非常清晰易懂
在这里插入图片描述
我从右往左来分析这个框架。

知识蒸馏本质上是对学生网络的训练过程。整个训练过程包含两个分支:

  1. 第一个分支由学生网络和训练好的教师网络在相同的温度T下进行蒸馏。面对transfer set中的一个样本,教师网络给出的预测结果作为soft targets或者说soft labels,即作为提供给学生网络的标签;而学生网络给出的预测结果作为soft predictions预测值。该分支的损失函数被称为soft loss或distillation loss,由labels和predictions作交叉熵计算得到。
  2. 第二个分支由学生网络在温度T=1下单独进行,提供人为打好的绝对正确的labels(hard labels),预测值由学生网络提供。该分支的损失函数被称为hard loss或student loss,同样由labels和predictions作交叉熵计算得到。

给出运算时涉及到的四种数据的范围。

数据取值范围温度
soft labels[0,1]T
soft predictions(0,1)T
hard labels{0,1}\
hard predictions(0,1)1

不难发现,其实第二个训练分支就是一个普通的模型训练过程。两个过程都是监督训练,一个监督来自teacher,另一个来自ground truth。如果把第一个分支比喻为老师指导学生学习的话,第二个分支就是学生对着教材书本进行训练。

为了使得学生网络向教师网络学到的结果或者说生成的soft targets靠拢,我们需要设置一个总的损失函数Total loss用于反向传播和梯度下降。如图,Total loss由两部分损失函数加权求和得到,并且当hard loss权重比soft loss权重小得多时效果更好。

四、实验

作者在经典的MNIST手写数据集上进行了实验验证。实验反映出了一些需要注意的地方。

1. 平移不变性的学习

我们知道,CNN具有平移不变性,即当图片中需要识别的物体的外观发生变化(如平移、旋转等)时,CNN也能将其识别出来。

在实验中作者发现,学生网络不仅能学习教师网络的分类知识,还能学习到教师网络的平移不变性,即使提供的transfer set中并没有包含任何经过变换的样本数据。

2. 蒸馏温度的设置

我们在前面已经分析过,当T=1时即一般的训练中,softmax层输出的类别预测概率分布相对比较hard。如果我们逐渐提高温度T,那么这一概率分布就会变得越来越soft。

然而,实验发现,当T很高时,概率分布会几乎失去起伏波动而趋向于一条水平线,这样分类就失去了意义,这也是我们不希望看到的结果。因此,设置合适的温度T非常重要。

实验验证说明,温度T的设置与学生网络的结构有关。用于实验的学生网络有两层隐含层,当每层包含300+个神经元时,温度T在8以上的训练结果差异不大且效果不错;当每层神经元被减少到30左右,温度T设置为2.5~4能取得更好的效果。

3. 零样本学习

实验还发现,当我们将transfer set中某一个类别的样本全部去掉后,学生网络依然能从教师网络学习到该类别的相关知识。这和前面提到的平移不变性学习是类似的。但是,学生网络对于该类别所犯的预测错误相对较多,并且犯下的错误多是由于学习到的关于该类别的偏置较小导致的。

举个例子,我们把MNIST中数字3的样本全部去掉,那么学生网络就会在识别3时更容易犯错,这是由于在训练过程中,学生网络学到的关于3这一类别对应偏置很小。因此,为了改善,我们在测试时需要人为调高类别3对应的偏置值。

类似地,如果我们只保留MNIST中数字7和8的样本进行训练,那么学生网络学到的7和8对应类别的偏置就会非常高。因此测试时我们需要人为调低这些偏置。

五、总结

KD这篇论文,是我在学习其他知识蒸馏论文时遇到瓶颈后,才选择进行阅读的。阅读之后,发现这篇论文对于知识蒸馏后续方法的学习,具有很大帮助,很值得进行精读。之后我会更加侧重对知识蒸馏方法的代码实践。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
知识蒸馏Knowledge Distillation)是一种将一个较大的模型的知识转移到一个较小的模型的技术。这种技术通常用于减小模型的大小和推理成本,同时保持模型在任务上的性能。 在Python中,你可以使用以下步骤来实现知识蒸馏: 1. 准备教师模型和学生模型:首先,你需要准备一个较大的教师模型和一个较小的学生模型。教师模型通常是一个预训练的大型模型,例如BERT或其他深度学习模型。学生模型是一个较小的模型,可以是一个浅层的神经网络或者是一个窄的版本的教师模型。 2. 训练教师模型:使用标注数据或其他训练数据集来训练教师模型。这个步骤可以使用常规的深度学习训练方法,例如反向传播和随机梯度下降。 3. 生成教师模型的软标签:使用教师模型对训练数据进行推理,并生成教师模型的软标签。软标签是对每个样本的预测概率分布,而不是传统的单一类别标签。 4. 训练学生模型:使用软标签作为学生模型的目标,使用训练数据集来训练学生模型。学生模型的结构和教师模型可以不同,但通常会尽量保持相似。 5. 进行知识蒸馏:在训练学生模型时,除了使用软标签作为目标,还可以使用教师模型的中间层表示或其他知识来辅助学生模型的训练。这可以通过添加额外的损失函数或使用特定的蒸馏算法来实现。 以上是实现知识蒸馏的一般步骤,具体实现细节可能因应用场景和模型而有所不同。你可以使用深度学习框架(如TensorFlow、PyTorch等)来实现这些步骤,并根据需要进行调整和扩展。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值