《Knowledge Distillation with the Reused Teacher Classifier》——论文解读
学习目标:
1.通过阅读学习如何提高学生模型的泛化能力,我们不止希望学生模型在训练集和测试集上表现优良(“课内作业”),我们同样希望它们能在“课外学业”上表现得令人信服。
2.了解什么是特征对齐,并且掌握基于正则化对齐的方法。
Abstract
蒸馏技术一直以来以将笨重而又复杂的教师模型转变为一个轻量级但是表现上与教师相差不多得学生模型而著名。但是很多人的故意设计反而增加了这个技术的难度,作者提出了一种比较简单的蒸馏方法:通过重用教师的分类器(用在学生模型上),只用一个L2正则化技术就可以将学生模型的表现和教师模型接近(前提是二者提取的特征完全对齐)。另外作者开发了“投影仪”技术,该“投影仪”可以用在各种教师学生模型上,并且和其他方法相比可以达到最好的效果。
下图为作者提出的结构图,作者的想法是在最终分类器的前一层特征层采用L2正则化技术实现特征对齐,在整个训练过程中,只有学生模型的Encoder(编码器)和Projector(投影仪)进行更新。
什么是特征对齐?
在深度学习中,特征对齐通常涉及到在训练过程中引入额外的约束或损失,以确保模型对输入数据的不同变体具有一致的表示。这对于处理域适应(domain adaptation)和多模态学习(multi-modal learning)等任务非常有用,因为不同的领域或模态可能具有不同的特征分布。
我在这里放另一篇论文的解读链接(里面有论文链接),有兴趣的小伙伴可以去深入了解一下特征对齐
什么是基于正则化的特征对齐?
客户端需要同时考虑监督学习损失和泛化误差来更新局部模型。为此,我们利用全局特征质心,并在局部训练目标中引入新的正则化项,使局部表示学习受益于全局数据。局部正则化项如下所示:
图1 基于正则化的特征对齐
其中 fθi (xj ) 是给出的xl的局部特征嵌入层,Cyj是yj类对应的全局特征质心,λ是平衡监督损失和正则化损失的超参数,通过利用全局语义特征信息,这样的正则化术语对每个客户机都有很大的好处。直观地说,它使每个客户端能够通过显式的特征分布对齐来学习任务的不变表示形式。因此,局部特征提取器{θi} mi =1的多样性也可以在最小化局部分类误差的同时进行正则化(也就是网络的参数)。
上式在论文推理过程中的来源为(了解即可):
直观地说,这个误差项(上式)是由局部-全局特征表示的不一致直接引起的,这意味着学习到的特征分布不紧凑。
然而,在训练损失和泛化差距之间存在权衡,同时最小化两者是不可行的!
所以作者想到了用图一的形式来显式地对齐局部-全局的特征表示
我们回到我们要分析的论文
1. Introduction
在介绍部分,首先提出知识蒸馏技术knowledge distillation (KD)已经在各个领域展开并表现良好,但关键问题还是在于学生模型和教师模型的性能差异依然明显,有很多人在此项任务上提出改进,但都差强人意。随后作者的方法横空出世。
作者观点(黑体字):教师模型强大的类别预测能力不仅归功于这些表达特征,而且同样重要的是,它是一个分类判别器。基于这一论点,我们通过在分类器的前一层的特征对齐来训练学生模型,并直接复制教师分类器进行学生推理。这样,如果我们能够将学生的特征与教师模型的特征完美地结合起来,他们的表现差距就会消失。也就是说,仅特征对齐误差就可以解释学生推理的准确性,这使得我们的知识迁移更容易理解。根据实验结果,一个单一的L2损失对于特征对齐的效果已经非常好了。这种简单的损失使我们不必像以前的工作那样仔细调整超参数,以平衡多重损失的影响。
由于从教师和学生模型中提取的特征尺寸通常不同,因此在学生特征编码器之后添加一个投影仪来弥补这种尺寸不匹配。这种投影仪通常在师生压缩中产生不到3%的剪枝比成本,但它使我们的技术适用于任意模型架构。在少数情况下,添加的投影仪加上重用的教师分类器的参数数小于原始学生分类器的参数数,甚至可以扩大剪枝比。我们在标准基准数据集上进行了广泛的实验,并观察到我们的SimKD始终优于各种师生架构组合的所有比较先进的方法。我们还表明,我们的简单技术在多教师知识蒸馏和无数据知识蒸馏等不同场景下都有很好的泛化效果。
作者上述使用投影仪技术使得教师模型和学生模型特征对齐的方法,在特征蒸馏中也有出现。
图2 特征蒸馏示意图
作者观点:想要减少教师和学生之间的性能传递下降,就需要多使用预训练教师模型的更多信息。
2. Related Work
随后作者提出相比于之前的利用各种什么梯度信息,他们能够运用假设迁移学习的理论将教师分类器用在学生模型上,并且加上L2正则化,在少量的有标签目标数据集和大量无源目标数据集上训练,仍然取得非常不错的成果。
3. Method
终于进入作者的方法这一章,作者通过对比三种蒸馏方式来阐述自己模型的合理性
三个模型的主要区别在于如何形式化梯度以及梯度流从哪里开始。Vanilla KD计算类别预测中的梯度,并依靠这个梯度来更新整个学生模型。(个人认为这是一种非常传统的方式,并且效果很不理想);特征蒸馏在上文已经提及,其通过各种知识表示从中间层收集更多的梯度信息。额外的超参数需要小心地转换以获得最佳性能(个人认为本质上还是是Teacher将特征级知识迁移给Student);而作者提出的SimKD在分类器的前一层计算L2损耗,并单独反向传播该梯度来更新学生特征编码器和维度投影仪。
3.1. Vanilla Knowledge Distillation
作者先卖了个关子,提了以下深度学习用于图像分类的深度学习模型常规架构,就是多层MLP+softmax,但是在知识蒸馏中,需要加入温度“T”,这个T就加在softmax函数里(或者其他激活函数),具体公式如下图:
其中
g
s
=
W
s
f
s
,而
f
s
是上一层的输出层,
W
s
是最后一层的权重矩阵
其中g^s=W^sf^s,而f^s是上一层的输出层,W^s是最后一层的权重矩阵
其中gs=Wsfs,而fs是上一层的输出层,Ws是最后一层的权重矩阵
下图为Vanilla Knowledge Distillation的损失函数,与交叉熵损失相比,引入的预测对齐损失为不正确的类提供了额外的信息,以方便学员的训练
3.2 Simple Knowledge Distillation(包含特征蒸馏)
虽然引入了中间隐藏层的梯度信息,然而,它们的成功很大程度上依赖于那些特别设计的知识表示,以引入适当的归纳偏差,并且精心选择超参数来平衡不同损失的影响。两者都是劳动密集型和耗时的。也很难断定某一类型的表征在学生培养中所起的实际作用。
3.3 SimKD(作者模型)
作者观点:相比之下,我们提出了一种简单的知识蒸馏技术,称为SimKD,它打破了这些严格的要求,同时仍然在广泛的实验中获得最先进的结果。如图c所示,SimKD的一个关键要素是“分类器重用”操作,即我们直接借用预先训练好的教师分类器进行学生推理,而不是训练一个新的分类器。这样就不需要标签信息来计算交叉熵损失,使特征对齐损失成为生成梯度的唯一来源。
下面就是作者在解释他们的模型为什么合理了(其实任何模型的提出都要自圆其说,这也是语言的艺术):
总的来说,我们认为教师分类器中包含的差异性信息 很重要,但在KD的文献中被很大程度上忽视了。然后,我们为其重要作用提供了一个合理的解释。考虑这样一种情况:一个模型被要求处理多个具有不同数据分布的 任务,基本做法是 冻结或共享一些浅层作为跨不同任务的特征提取器,同时对最后一层进行微调以 学习特定于任务的信息。在这种单模型多任务设置中,现有的工作认为任务不变信息可以共享,而任务特定信息需要独立识别,通常由最终分类器识别。对于KD,在同一数据集上训练具有不同能力的教师和学生模型,类似地,我们可以合理地认为,在不同模型中容易获得的数据中存在一些能力不变的信息,而强大的教师模型可能包含更简单的学生模型难以获得的额外基本能力特定信息。此外,我们假设大多数特定于能力的信息都包含在深层中,并期望重用这些层,即使只有最终分类器也会对学生训练有帮助。
可以看得出作者是在用心思考的,学生通过学习教师的深层参数(体现在最终分类器上),实现了对参数的提取,从而减少了大量的超参数调整工作和计算。
作者把损失函数的重点放在了中心正则化上,如下:
这个公式以前就出现了,但是作者只是在探究这个公式对于教师-学生模型的潜在应用价值。
SimKD具有良好的可解释性。请注意,来自预训练的教师模型的重用部分允许合并更多的层,但不仅限于最终的分类器。通常,重用更多的层可以提高学生的准确率,但会增加推理的负担。
4. Experiments
4.1 与其他知识蒸馏模型对比
实验部分的参数和数据集的具体细节在这里不在列出,我们重点看一下实验结果:
整体性能比较,通过简单的l2和公用classifier的方式达到了一个较好的性能:
可以看到,SimKD在各个数据集上的表现都略胜一筹
4.2 分类器-重用操作分析
(1)联合训练
作者联合训练学生特征编码器及其相关的分类器,然后使用他们自己的分类器或重用的教师分类器报告学生模型的测试准确性:
为了验证重用teacher模型的classifier的有效性,作者还提供了一组实验,训练student使用classifier,在推理的时候使用不同classifier的性能对比。从图中可以看到SimKD是本文提出的方法,KD是传统再试试蒸馏的方法。中间两条线分别是训练classifier时不同损失系数的性能。SimKD的性能的最高(也就是训练的时候仅仅使用特征损失不训练分类器,测试直接使用teacher的分类器效果最好)。而训练的时候同时训练分类器则对最终的性能有影响。
其中损失函数定义为:
(2)连续学习
在进行特征对齐后,作者固定学生特征编码器,即冻结提取的特征,只训练最后的softmax层,得到的精准度直线下降,如该表所示
这表明,即使提取的特征已经对齐,训练一个令人满意的学生分类器仍然是一个挑战。通常,我们可以通过调整分类器训练中的超参数来获得更好的学生性能。
此外,作者也提出,不仅仅重用分类器,重用更多的teacher特征也可以带来准确率的提升比如再继续重用最后一两个block(SimKD+ 和SimKD++)。如下图所示
4.3 Projector Analysis
projector的消融,说明一层还是不足以将特征合理的与teacher进行映射,不同的映射方式带来的效果也不同,如下图:
对于未来的应用: