写在前面
这是一篇关于知识蒸馏的简述文,为了帮助读者以及我自己能对知识蒸馏的发展脉络有一个清晰的了解进而有所感悟,所以才决定写下这篇水文
本文根据student网络知识的来源,将知识蒸馏分为四大类:output logits transfer,output transfer,middle layer transfer,later hidden layer transfer,每一个分类将用一篇水文来介绍
介绍内容总体分为papers和讨论两部分,paper中的每一篇论文将从方法,亮点来讲述
本篇水文就从output losgits transfer开始讲述
output logits transfer
output logits指分类或识别任务softmax激活函数之前的网络最后的输出层
目录
Logits
论文信息
论文题目:Do Deep Nets Really Need to be Deep?
方法
深层 CNN ensemble 模型用于 CIFAR-10 用于生成 logits(before the softmax activation.), 然后使用老师的logits作为回归目标来训练学生网络, 从而完成对老师网络的模仿. 通过直接在logits上训练学生模型,学生可以更好地学习老师学到的内部模型,而不会丢失掉因由于softmax的竞争性特性会损失掉输入的信息.
亮点
知识蒸馏开山之作
KD from Noisy Teachers
论文信息
论文题目:Deep Model Compression: Distilling Knowledge from Noisy Teachers
方法
- 损失函数介绍
之前的的logits transfer如式(1)所示
本篇论文中,作者在z中加入扰动,如(2)式所示,让student学习扰动后的logits
ξ 是一个均值为0,方差为σ的随机向量,σ是超参数
- 与正则化的关系
- 方法流程
通过概率alpha,在mini-batch中选择一部分样本做扰动
通过式3计算损失函数
更新参数
亮点
在teacher网络的logits中加入了扰动,相当于一个老师变成了多个老师
放宽了logits的规则,有些正则化的意思
Rocket Launching
论文信息
论文题目:Rocket Launching:A Universal and Efficient Framework for Training Well-performing Light Net
方法
- 框架
- 损失函数
- 同时训练
同时训练使得light net不仅可以学到booster网络的最终的输入,而且可以学到booster是怎么一步一步走到最终的结果的,这个过程教导是之前的知识蒸馏方法没有的
- Gradient block
booster和light net同时训练时,hint loss会阻止booster直接从target中训练,从而影响到他的表现,而light net的知识又是从booster中来的,这也会进而影响到light net的表现
为了解决这个问题,作者采用gradient block阻止hint loss传入booster,这使得Wb不接受来自于hint loss的导数
亮点
同时训练的方式,使student不仅学到了最终的知识,还学到了如何走到终点的过程
共享参数的做法使得student网络有了更强的底层表达能力
gradient blok阻止了KD loss传入teacher,防止了student网络干扰teacher网络
KD with CAN
论文信息
论文题目:Training Shallow and Thin Networks or Acceleration via Knowledge Distillation with Conditional Adversarial Networks
方法
- 框架
像普通的GAN一样,D用来分辨输入是有teacher网络生成的(Real)还是由student网络生成的(Fake)
student网络充当G的功能,尽可能生成让D不能分辨的输出
- LOSS
1. 判别器损失:
xi表示student和teacher网络的输入,ti表示teacher网络的输出
D表示判别器,F表示student网络
判别器损失由两部分组成,一部分是对抗损失LA,另一部分是对齐损失LDS
对齐损失是为了解决最终student网络输出与真实标签不匹配的问题,不不影响整体架构的了解,这里不展开说明
2. 生成器(student网络)损失:
生成器损失由三部分组成,一部分是真实标签的交叉熵损失LS,一部分是辅助logits学习的L1损失,还有一部分是对抗损失LGAN
LS是为了使用标签中原有的信息
加入L1的帮助是考虑到生成器的能力太弱,会导致收敛速度很慢
- 训练过程
判别器目标使使(5)式越大越好
生成器目标是使(7)式越小越好
采用交替训练的方式
判别器训练时,固定生成器,只更新判别器的参数
生成器训练时,固定判别器,只更新生成器参数
亮点
使用GAN帮助student学习teacher网络的logits,这可以保留logits的多样性
如上图所示,这两种分布都是比较合理的标签形式,都可以作为知识来训练student,但是直接用teacher分布来训练会导致student学到的分布过于单一化,而让GAN去分析student网络学到的东西是否合理,合理就鼓励,可以增加学到的分布的多样性
讨论
这几篇论文从logits的简单回归,到后来加入noisy,再到后来加入学习过程的学习,最后通过GAN来增强logits的多样性,可以从整个发展流程看出,作者们都是想要让知识变得越来越丰富,只是用的方法和额外的知识来源不同