知识蒸馏简述(一)

写在前面

这是一篇关于知识蒸馏的简述文,为了帮助读者以及我自己能对知识蒸馏的发展脉络有一个清晰的了解进而有所感悟,所以才决定写下这篇水文

本文根据student网络知识的来源,将知识蒸馏分为四大类:output logits transfer,output transfer,middle layer transfer,later hidden layer transfer,每一个分类将用一篇水文来介绍

介绍内容总体分为papers和讨论两部分,paper中的每一篇论文将从方法,亮点来讲述

本篇水文就从output losgits transfer开始讲述

output logits transfer

output logits指分类或识别任务softmax激活函数之前的网络最后的输出层

目录

Logits

论文信息

方法

亮点

KD from Noisy Teachers

论文信息

方法

亮点

Rocket Launching

论文信息

方法

亮点

KD with CAN

论文信息

方法

亮点

讨论


Logits

论文信息

论文题目:Do Deep Nets Really Need to be Deep? 

论文链接    论文解读

方法

深层 CNN ensemble 模型用于 CIFAR-10 用于生成 logits(before the softmax activation.), 然后使用老师的logits作为回归目标来训练学生网络, 从而完成对老师网络的模仿. 通过直接在logits上训练学生模型,学生可以更好地学习老师学到的内部模型,而不会丢失掉因由于softmax的竞争性特性会损失掉输入的信息.

亮点

知识蒸馏开山之作

KD from Noisy Teachers

论文信息

论文题目:Deep Model Compression: Distilling Knowledge from Noisy Teachers

论文链接    论文解读    论文实现

方法

  • 损失函数介绍

之前的的logits transfer如式(1)所示

本篇论文中,作者在z中加入扰动,如(2)式所示,让student学习扰动后的logits

ξ 是一个均值为0,方差为σ的随机向量,σ是超参数

  • 与正则化的关系

  • 方法流程

通过概率alpha,在mini-batch中选择一部分样本做扰动

通过式3计算损失函数

更新参数

亮点

在teacher网络的logits中加入了扰动,相当于一个老师变成了多个老师

放宽了logits的规则,有些正则化的意思

Rocket Launching

论文信息

论文题目:Rocket Launching:A Universal and Efficient Framework for Training Well-performing Light Net

论文链接    论文解读

方法

  • 框架

  • 损失函数

  • 同时训练

同时训练使得light net不仅可以学到booster网络的最终的输入,而且可以学到booster是怎么一步一步走到最终的结果的,这个过程教导是之前的知识蒸馏方法没有的

  • Gradient block

booster和light net同时训练时,hint loss会阻止booster直接从target中训练,从而影响到他的表现,而light net的知识又是从booster中来的,这也会进而影响到light net的表现

为了解决这个问题,作者采用gradient block阻止hint loss传入booster,这使得Wb不接受来自于hint loss的导数

亮点

同时训练的方式,使student不仅学到了最终的知识,还学到了如何走到终点的过程

共享参数的做法使得student网络有了更强的底层表达能力

gradient blok阻止了KD loss传入teacher,防止了student网络干扰teacher网络

KD with CAN

论文信息

论文题目:Training Shallow and Thin Networks or Acceleration via Knowledge Distillation with Conditional Adversarial Networks

论文链接    论文解读    论文实现

方法

  • 框架

像普通的GAN一样,D用来分辨输入是有teacher网络生成的(Real)还是由student网络生成的(Fake)

student网络充当G的功能,尽可能生成让D不能分辨的输出

  • LOSS

1. 判别器损失:

  xi表示student和teacher网络的输入,ti表示teacher网络的输出

  D表示判别器,F表示student网络

  判别器损失由两部分组成,一部分是对抗损失LA,另一部分是对齐损失LDS

  对齐损失是为了解决最终student网络输出与真实标签不匹配的问题,不不影响整体架构的了解,这里不展开说明

2. 生成器(student网络)损失:

  生成器损失由三部分组成,一部分是真实标签的交叉熵损失LS,一部分是辅助logits学习的L1损失,还有一部分是对抗损失LGAN

  LS是为了使用标签中原有的信息

  加入L1的帮助是考虑到生成器的能力太弱,会导致收敛速度很慢

  • 训练过程

判别器目标使使(5)式越大越好

生成器目标是使(7)式越小越好

采用交替训练的方式

判别器训练时,固定生成器,只更新判别器的参数

生成器训练时,固定判别器,只更新生成器参数

亮点

使用GAN帮助student学习teacher网络的logits,这可以保留logits的多样性

如上图所示,这两种分布都是比较合理的标签形式,都可以作为知识来训练student,但是直接用teacher分布来训练会导致student学到的分布过于单一化,而让GAN去分析student网络学到的东西是否合理,合理就鼓励,可以增加学到的分布的多样性

讨论

这几篇论文从logits的简单回归,到后来加入noisy,再到后来加入学习过程的学习,最后通过GAN来增强logits的多样性,可以从整个发展流程看出,作者们都是想要让知识变得越来越丰富,只是用的方法和额外的知识来源不同

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值