知识蒸馏基础


Logits Distillation

参考论文:https://arxiv.org/abs/1503.02531
知识蒸馏框架
损失计算:
损失计算
以分类问题举例
Soft-Target:关于分类结果的概率输出
Hard-Target:GT的类别,是一个独热编码
Loss:

  • Hard Loss,即在Student Model的输出和Hard Target之间计算交叉熵损失CE
  • Soft Loss,在Teacher Model和Student Model进行SoftMax操作之前先进行升温操作(除以一个不为1的常数T),将SoftMax操作之后的输出进行均方误差损失MSE(也可以是交叉熵损失CE)

T o t a l L o s s = λ × S o f t L o s s + ( 1 − λ ) × H a r d L o s s Total Loss = \lambda \times Soft Loss + (1 - \lambda ) \times Hard Loss TotalLoss=λ×SoftLoss+(1λ)×HardLoss
升温操作:
在这里插入图片描述

  • T = 1 时,就是和原始模型的概率分布输出一致
  • 0 < T < 1时,输出的概率分布和原始模型的输出相似
  • T > 1时,输出的概率分布比原始模型的输出更平缓
  • 随着T的增加,Softmax 的输出分布越来越平缓,信息熵会越来越大。温度越高,softmax上各个值的分布就越平均,思考极端情况,当 ,此时softmax的值是平均分布的。

Feature Distillation

让Student Model学习Teacher Model的中间层输出
在这里插入图片描述
目标:把“宽”且“深”的网络蒸馏成“瘦”且“更深”的网络。
在这里插入图片描述
主要分成两个阶段:

  1. Hints Training:Teacher的1-N层为 W h i n t W_{hint} Whint ,第N层输出为Hint;Student的1-M层为 W g u i d e d W_{guided } Wguided,第M层输出为Guided。Hint和Guided的维度可能不匹配,使用一个卷积适配器r将Guided映射到和Hint的维度匹配,最小化下面这个Loss:
    在这里插入图片描述
    算法流程伪代码:
    在这里插入图片描述

Distillation Scheme

在这里插入图片描述

Offline Distillation

离线蒸馏是最传统的知识蒸馏方法。在这个方案中,首先独立地训练一个大的、复杂的模型(教师模型)。教师模型通常会在数据集上进行充分的训练,直到达到很高的精度。一旦教师模型被训练好,它的输出(通常是分类任务中的软标签)就被用来指导小的模型(学生模型)的训练。学生模型的训练是在教师模型固定之后进行的,因此称之为“离线”蒸馏。学生模型试图模仿教师模型的输出,以此来达到比自己直接在数据集上训练更好的性能。

Online Distillation

在线蒸馏指的是教师模型和学生模型同步训练的情况。在这个方案中,不需要预先训练一个固定的教师模型,而是让教师和学生在同一时间内相互学习。有时候,教师模型会在训练过程中动态更新,即学生模型的训练过程中教师模型也在不断地学习和改进。这种方法允许学生模型从教师模型的即时反馈中学习,而教师模型也可以根据学生模型的进展进行调整。在线蒸馏使得模型训练更加灵活,因为它不需要一个训练完成的教师模型作为起点。

Self Distillation

自蒸馏是一个相对较新的概念,它不涉及两个不同的模型,而是在同一个模型上进行迭代训练。一个模型首先被训练,然后它自己的输出被用作下一轮训练的目标。简单来说,模型首先以一定方式训练(例如使用硬标签),一旦完成,模型的预测(软标签)被用来再次训练模型。这个过程可以重复多次,每次模型都尝试模仿自己先前版本的输出。自蒸馏允许模型在没有外部教师模型的情况下提高其性能。这种方案的优势是简单和成本低。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值