【知识蒸馏】(附代码)

目录

前言

一、Soft-target蒸馏

二、关于温度T

三、代码案例

参考文献


前言

        Hinton等人于2015年在文章《Distilling the Knowledge in a Neural Network》中提出了知识蒸馏这个概念(国内有一些不一样的声音,周志华老师在03年的时候, 就在论文中提出了相似的想法, 然后在Hinton的nips大会上演讲的时候, 周老师的学生就当面向Hinton提出来这个事情, 不过Hinton自始至终也没有回应),其核心思想是先训练一个复杂网络模型,然后使用这个复杂网络的输出和数据的真实标签去训练一个更小的网络

      Teacher-Student模式:将复杂,学习能力强的模型作为Teacher,结构较为简单,学习能力相对较弱模型作为Student,用Teacher来辅助Student模型的训练,将Teacher学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力。复杂笨重但是效果好的Teacher模型不上线,作为导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。

        知识蒸馏是将大模型的能力迁移到小模型上,根据迁移的方法不同可以简单分为基于目标蒸馏(也称为Soft-target蒸馏或Logits方法蒸馏)和基于特征蒸馏两个大的方向。本文主要讲述的是Soft-target蒸馏。

一、Soft-target蒸馏

        我们首先训练一个泛化能力较强的Teacher模型,在利用Teacher模型来蒸馏训练Student模型时,通过使用softmax层输出的类别的概率来作为“Soft-target” ,可以直接让Student模型去学习Teacher模型的泛化能力。

具体步骤:

  1. 训练好Teacher模型;
  2. 利用高温得到Soft-target;
  3. 使用{Soft-target,}以及{Hard-target,T=1}训练student模型;
  4. 设置T=1,使用student模型做线上的inference。

如何理解上面这个流程图呢?

       一批训练数据输入一个已经训练好的Teacher模型中和没有训练过的Student模型,注意两个模型中都需要设置高温“T=t”,将两个模型得到的结果做一个交叉熵得到distillation loss;同时Student模型在正常温度“T=1”下得到一个输出,将其和真实的标签信息做一个交叉熵得到Student loss;最后将两个loss进行加权求和得到最终的损失函数。     

一些定义(很多都是来自论文里):

Logit:输入数据经过神经网络进行各种非线性变换后,在网络最后Softmax层之前得到的属于各个类别的大小数值z_{i}

Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。

Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。

softmax函数

加了温度T后的softmax函数

高温蒸馏过程的目标函数

L-soft:用Teacher模型在高温 T下产生的softmax distribution来作为Soft-target,Student模型在相同温度条件下的softmax输出和Soft-target的cross entropy ,具体形式为:

       其中,p_{i}^{T}指Student的在温度等于 T 的条件下softmax输出在第 i 类上的值。q_{i}^{T}指Student的在温度等于 T 的条件下softmax输出在第 i 类上的值。

L-hard:Student模型在T=1的条件下的softmax输出和ground truth的cross entropy,具体形式为:

       c_{i}指在第 i类上的ground truth值, c_{i}∈{0,1} , 正标签取1,负标签取0。Teacher模型也有一定的错误率,使用ground truth可以有效降低错误被传播给Student模型的可能性。 \alpha\beta为权重。

二、关于温度T

       为什么要使用高温T?知识蒸馏用Teacher模型预测Soft-target来辅助Hard-target训练Student模型的方式为什么有效呢?

        softmax层的输出,除了正例之外,负标签也带有Teacher模型归纳推理的大量信息,比如某些负标签对应的概率远远大于其他负标签,则代表 Teacher模型在推理时认为该样本与该负标签有一定的相似性。而在传统的训练过程(Hard-target)中,所有负标签都被统一对待。也就是说,使用高温T的知识蒸馏的训练方式使得每个样本给Student模型带来的信息量大于传统的训练方式。

       温度的高低改变的是Student模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Student模型会相对更多地关注到负标签。

       总的来说, T 的选择和Student模型的大小有关,Student模型参数量比较小的时候,相对比较低的温度就可以了。因为参数量小的模型不能学到所有Teacher模型的知识,所以可以适当忽略掉一些负标签的信息。

三、代码案例

       MNIST案例代码来自参考文献【3】,Teacher模型3个全连接层每层1200个神经元,Student模型为3个全连接层每层20个神经元,蒸馏温度T= 7,

结果如下:

Teacher模型结果:

单独训练Student模型结果:

知识蒸馏训练Student模型结果:


参考文献

[1] Hinton E G ,Vinyals O ,Dean J . Distilling the Knowledge in a Neural Network.[J]. CoRR,2015,abs/1503.02531.

[2] 深度学习中的知识蒸馏技术,https://zhuanlan.zhihu.com/p/353472061

[3] 知识蒸馏:《Distilling the Knowledge in a Neural Network》算法介绍及PyTorch代码实例,https://blog.csdn.net/weixin_44808161/article/details/126083190

[4] 知识蒸馏开山之作论文精读:Distilling the knowledge in a neural network_哔哩哔哩_bilibili

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值