《理解知识蒸馏》
本文是对知识蒸馏方法在分类场景的应用来讨论的。知识蒸馏是一种从复杂模型向简单模型迁移知识的方法,我觉得和Label smooth有相似的地方,就是soft-labels,soft-labels有正则的功能,而且具有容错性。但是显然知识蒸馏得到的soft-labels会包含更多的所谓dark knowledge,能够让模型get到更多的信息。
Key Words:Label smooth、dark knowledge、knowledge distillation
Beijing, 2020
作者:RaySue
Code:
提出问题
当两个类别共享一系列特征的时候,一个类别的硬概率为100%,而另一个类别的概率为0,将会对每个类别学习的共享特征产生不利影响,导致深度学习分类模型的整体准确率下降,比如两个类别很像的情况,就会导致模型学习遇到困惑。
Label Smooth
对于上述问题,可以对其one-hot label进行平滑处理即Label Smooth,具体的做法也很简单,就是利用一个参数对非0即1的label进行软化处理,比如:使得1变为0.96,其余的类别平分0.04。可以缓解上述的问题。
这样的做法直观上就是让模型不是完全的相信label,这样数据中有一些模棱两可的处于边缘的样本,就不会因为人的主观分类,导致模型受到影响,这也是label smooth能够增加模型泛化性的原因。
更多细节,参见博客https://blog.csdn.net/racesu/article/details/107214035
知识蒸馏
Matching logits is a special case of distillation
符号说明
- logits:未被归一化的对数几率 log(Odds)一个事件发生与该事件不发生的比值的对数
- teacher: 原始模型或 Model Ensemble
- student: 较简单模型
- transfer set: 用来迁移teacher知识、训练student的数据集合
- soft target: teacher输出的预测结果(一般是softmax之后的概率)
- hard target: 样本原本的标签
- temperature: 蒸馏目标函数中的超参数
- born-again network: 蒸馏的一种,指student和teacher的结构和尺寸完全一样
- teacher annealing: 防止student的表现被teacher限制,在蒸馏时逐渐减少soft targets的权重
知识蒸馏的概念
知识蒸馏首先被 Bucila 等人在2006年提出来,然后被 Hinton 在2015年推广。
为了解决提出的问题,可以使用通过高准确率的cumbersome模型得到的软化的分类概率分布来训练模型,有和Label Smooth类似的效果。
知识蒸馏也是模型压缩的方法,一个小模型来模仿一个预训练的大模型(或集成模型)。这种训练方式也被称为“teacher-student”,其中的大模型被称为teacher,小模型被称为student。
在蒸馏阶段,知识从teacher模型向student模型通过最小化一个损失函数来迁移,这一项的学习的目标是通过teacher模型预测得到的类别概率的分布。
蒸馏的温度
数据集中的很多easy样本在teacher model预测的正确标签的概率是非常高的而其他类别的概率就为0了,这就导致了和hard target是一样的情况了。这样蒸馏方法就无法比数据集的真实标签提供更多的信息了。为了处理这个问题,Hinton等人介绍了"softmax temperature"的概念。
神经网络通常使用softmax层来转换logit产生分类概率,通过每个类别的logit, z i z_i zi 来计算其相应的概率值 q i q_i qi ,加入了参数"softmax temperature"之后就变为:
q i = e x p ( z i T ) ∑ j e x p ( z j T ) q_i = \frac{exp( \frac{z_i}{T})}{\sum_j exp(\frac{z_j}{T})} qi=∑jexp(Tzj)exp(Tzi)
其中 T T T 表示温度,当 T = 1 T=1 T=1的时候我们就得到了标准的softmax函数。
Hinton等人发现,除了teacher的soft labels真实的标签对于蒸馏模型效果也是有帮助的。因此,我们也计算学生预测的类别概率和真实标签的损失。我们称之为student loss。在计算这个loss的时候,我们设置T=1。
蒸馏的最简单形式中,知识通过在transfer set上训练来转移到蒸馏模型中,这个transfer set是笨重的模型通过在其softmax中利用高温度T推理得到的,迁移集合中的每个样本使用这种软化的目标分布作为label,在训练蒸馏模型中使用同样的高温,但是在训练蒸馏模型之前,笨重的模型需要用T=1先训练好。
整个distillation的损失函数由两部分组成:
s t u d e n t _ l o s s = C E ( g t , σ ( z s ; T = 1 ) ) ( 1 ) student\_loss = CE(gt, \sigma(z_s; T=1)) \space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space\space (1) student_loss=CE(gt,σ(zs;T=1)) (1)
d i s t i l l a t i o n _ l o s s = C E ( σ ( z t ; T = τ ) , σ ( z s ; T = τ ) ) ( 2 ) distillation\_loss = CE(\sigma(z_t;T=\tau), \sigma(z_s; T=\tau)) \space\space\space\space\space (2) distillation_loss=CE(σ(zt;T=τ),σ(zs;T=τ)) (2)
L ( x ; W ) = α ∗ s t u d e n t _ l o s s + β ∗ d i s t i l l a t i o n _ l o s s ( 3 ) L(x;W) = \alpha * student\_loss + \beta * distillation\_loss \space\space\space\space (3) L(x;W)=α∗student_loss+β∗distillation_loss (3)
- gt is the ground truth label (one-hot).
- z t z_t zt logit of teacher
- z s z_s zs logit of student
- CE is the cross-entorpy loss function.
- σ \sigma σ is the softmax function parameterized by the temperature T.
- α 、 β \alpha 、 \beta α、β are coefficients
其中的 τ , α , β \tau, \alpha, \beta τ,α,β是超参数,Hinton等人在实验中使用的 τ \tau τ 的范围从 1 到 20。他们发现 当student model比teacher model小很多的时候,更低的温度( τ \tau τ)会有更好的效果
因为Distillation loss是通过软化的目标求得的,所以产生的梯度量级就变为了 1 / T 2 1/T^2 1/T2,所以如果想同时使用hard和soft目标,则需要对Distillation loss部分乘以 T 2 T^2 T2,这样才能确保hard和soft目标保持相当的贡献。
知识蒸馏的理解
知识蒸馏过程中,把里面包含了大模型从数据中学习到的更具有泛化性的知识,迁移到较为简单的模型当中,从而让student能够学到更多的信息。
因为hard target的缺点之一是容易导致模型过拟合,所以知识蒸馏也是一种正则化的策略,提高泛化性,概率1和概率0鼓励所属类别和其他类别之间的差距尽可能加大,而由梯度有界可知,这种情况很难适应。会造成模型过于相信预测的类别,而知识蒸馏和LabelSmooth相似可以将 z i z_i zi 的范围限定一下,会避免这种问题。
随着 T T T的增加,通过softmax函数得到的概率分布就变得更加软化,提供了更多的信息,比如teacher model发现的哪些类别和预测的类别更相近。Hinton称之为嵌入在teacher model的“暗知识”,也就是我们在蒸馏阶段需要迁移到student model的暗知识。当计算和teacher的soft target的损失函数时,我们使用相同的温度 T T T 来计算学生的logits,我们称之为“蒸馏损失”。
论文在 MNIST 数据集合中验证了知识蒸馏的效果,一个比较让人惊讶的实验结果是,在transfer set中只使用MNIST除去数字3的数据,在数字3的1010个测试样本中只错了133个,这也是由于数字3没有参与训练导致的bias很低,如果bias调高到3.5那么在数字3的测试集合上只错了14个。
另一个实验更加惊艳,迁移集合中只使用 MNIST 数据集中的数字7和8来训练student model,测试集合的错误率是47.3%,同样因为数字7和8的bias太高了,降低到7.6后错误率仅为13.2%。
知识蒸馏训练的pipeline
知识蒸馏参数设定
温度 τ \tau τ 的选取
当student model比teacher model小很多的时候,较小的 T能够取得更好的结果。这也是解释的通的,因为当我们把softmax temperature调高了,意味着soft-labels分布包含了更加丰富的信息,而一个很小的模型可能无法捕捉其全部的信息。但是也没有一个明确的方式来预知student model会有捕捉什么程度信息的能力。
α , β \alpha,\beta α,β的选取
Hinton在选取这两个超参数的时候使用了加权平均的方法,即 β = 1 − α \beta = 1 - \alpha β=1−α。他们发现往往设置 α \alpha α 的值远低于 β \beta β 值的时候会得到最好的效果。比如 a l p h a = 0.05 , β = 0.95 alpha = 0.05, \beta = 0.95 alpha=0.05,β=0.95
总结 & 思考
对于多分类问题而言,它的标签只指向一个类别,而其他类别的概率为0,这样去做训练的时候,往往会导致模型拟合到一些细节,导致模型的泛化能力变低。而soft-labels能够在一定程度上避免这个问题,所以Label Smooth是有效的。
相比于LabelSmooth来说,Distillation产生的soft-labels隐式的含有类间关系的信息,增加模型泛化性的同时更能够增加模型的performance。
参数量相对多的模型都可以把知识迁移到参数量较少的模型上去,比如teacher model为ResNet-34,student model为ResNet-18,就可以作为一对知识蒸馏的teacher-student,在结构相似的情况下效果尤为明显。此外,一个模型可以通过其他的模型压缩的方式使其具有更少的承载力,也意味着更快的效率,比如稀疏化或量化。比如我们用模型量化的方式训练一个4-bit的ResNet-18,然后使用FP32的ResNet-18作为teacher model。同样也可以使用剪枝或正则化等手段联合使用。
知识蒸馏的形式是多种多样的,我们只需要理解其中的道理即可,也有的工程中直接使用student logits和teacher logits直接计算Mse的,也是蒸馏的一种实现。
参考
[1] Distilling the Knowledge in a Neural Network
[2] Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions
[3] 介绍了各种知识蒸馏 https://blog.csdn.net/nature553863/article/details/80568658
[4] Distiller(一个基于pytorch的 包含知识蒸馏及自动模型压缩,量化学习的项目)论文:https://arxiv.org/abs/1910.12232
[5] Distiller Repo: https://github.com/NervanaSystems/
[6] Distiller 如何训练知识蒸馏模型:https://nervanasystems.github.io/distiller/schedule.html#knowledge-distillation
[7] Distiller 介绍知识蒸馏 https://nervanasystems.github.io/distiller/knowledge_distillation.html
[8] KD 文章合集 https://github.com/FLHonker/Awesome-Knowledge-Distillation