全网最细图解知识蒸馏(涉及知识点:知识蒸馏实现代码,知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)

一.是什么?

把一个大的模型(定义为教师模型)萃取,蒸馏,把它浓缩到小的模型(定义为学生模型)。

即:大的神经网络把他的知识教给了小的神经网络。

在这里插入图片描述

二.为什么要用知识蒸馏把大模型学习到的东西迁移到小模型呢呢?

因为大的模型很臃肿,而真正落地的终端算力有限,比如手表,安防终端。
所以要把大模型变成小模型,把小模型部署到终端上。

在这里插入图片描述

2.1 轻量化网络的方向

分为下面四个方向,知识蒸馏是第一个方向。

在这里插入图片描述

三.用蒸馏温度处理学生网络的标签

学生网络有两种标签:

一种是教师网络的输出,
一种是真实的标签。

3.1 soft target

soft target使我们常用的概率版的标签值。比如输入下面的图片预测。
在这里插入图片描述
hard targets和soft targets的预测概率如下:
在这里插入图片描述
hard targets的预测结果不科学,因为马和驴比马和汽车相似的多。所以驴和汽车都是0,没有表现出这个信息,所以要用soft targets.

3.2 用教师网络预测出的soft target作为学生网络的标签。

教师网络预测出的soft target具有很多信息。

3.3 蒸馏温度

softmax有放大差异的功能。
如果值高那么一点点,经过softmax的放大就会变得很高。
如果想让soft target更加平缓,高的降低,低的升高。
这时就要对soft target使用蒸馏温度。 让soft target更soft。
实现方法是在softmax的分母处加个T。

在这里插入图片描述

效果如下:T=1时相当于没有蒸馏温度。T=3时确实低的更低高的更高了。

在这里插入图片描述
在这里插入图片描述

T和分布的关系如下图,T从1增加到10,值之间的差异越来越小,softmax的放大效果被冲淡。
当T=100的时候,结果直接变成一个横线,众生平等。

在这里插入图片描述

3.4为什么要加入蒸馏温度T让softmax的结果更平滑?

  1. 抑制过拟合: 高蒸馏温度下的软目标概率分布更平滑,相比硬目标更容忍学生模型的小误差。这有助于防止学生模型在训练过程中对教师模型的一些噪声或细微差异过度拟合,提高了模型的泛化能力。

  2. 降低标签噪声的影响: 在训练数据中存在标签噪声或不确定性时,平滑的软目标可以减少这些噪声的影响。学生模型更倾向于关注教师模型输出的分布,而不是过于依赖单一的硬目标。

  3. 提高模型鲁棒性: 平滑的软目标有助于提高模型的鲁棒性,使其对输入数据的小变化更加稳定。这对于在实际应用中面对不同环境和数据分布时的模型性能至关重要。

需要注意的是,过高的蒸馏温度也可能导致学生模型过于平滑化,失去了对数据细节的敏感性,因此需要在实践中进行调优。

四.知识蒸馏训练过程

4.1 图示知识蒸馏训练过程

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

上面是已经训练好的教师网络。
把数据输入到教师网络,在输出时使用蒸馏温度为T的softmax.
再把数据输入到学生网络,学生网络可能是还没有训练的网络,也可能是训练一半的半成品网络。  

4.2 损失函数

学生网络既要在蒸馏温度等于T时与教师网络的结果相接近。
也要保证不使用蒸馏温度时的结果与真实结果相接近。

蒸馏损失:

把教师网络使用蒸馏温度为t的输出结果 与 学生网络蒸馏温度为t的输出结果做损失。
让这个损失越小越好。

学生损失:

学生网络蒸馏温度为1(即不使用蒸馏网络)时的预测结果和真实的标签做loss.

最后对这两项加权求和。

4.3 图解损失函数计算过程

红色线条指向的是学生损失。
紫色线条指向的是蒸馏损失。

在这里插入图片描述

五.推理过程

此时学生网络已经训练好,把X输入到学生网络得到结果。
在这里插入图片描述

六.最终效果:

学生网络可以接近教师网络的识别效果,并且附加如下两个特点:

1.零样本识别

论文里面说:以手写体数字数据集为例,假如在训练学生网络时把标签为3的类别全部去掉,
但是教师网络学过3。当使用知识蒸馏将教师网络学到的东西迁移到学生网络时,学生网络虽然没有见过3,但是却能识别3,即达到了零样本的效果。

2.使用soft target训练而不是hard target,减少了过拟合

在这里插入图片描述

第二行和第三行是使用百分之3的训练样本并分别用hard target和soft target,结果显示

使用3%的训练样本 + hard target :
训练集的准确率为 67.3%, 测试集的准确率为44.5%
使用3%的训练样本 + soft target :
训练集的准确率为 65.4%, 测试集的准确率为57.5%

七.迁移学习和知识蒸馏的区别

迁移学习是把一个模型学习的领域泛化到另一个领域,比如把猫狗这些动物域迁移到医疗域。
知识蒸馏是把一个模型的知识迁移到另一个模型上。

八.知识蒸馏实现代码

https://blog.csdn.net/qq_42864343/article/details/134722507?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22134722507%22%2C%22source%22%3A%22qq_42864343%22%7D

九.参考视频

B站UP主,同济子豪兄的视频:
【精读AI论文】知识蒸馏
https://www.bilibili.com/video/BV1gS4y1k7vj/?spm_id_from=333.788&vd_source=ebc47f36e62b223817b8e0edff181613

  • 58
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
以下是使用知识蒸馏提取原型的示例代码,假设我们有一个已经训练好的分类器 `teacher_model`,我们希望从中提取出原型来训练一个新的分类器 `student_model`: ```python import torch import torch.nn as nn import torch.optim as optim # 定义原型提取器 class PrototypeExtractor(nn.Module): def __init__(self, teacher_model): super(PrototypeExtractor, self).__init__() self.teacher_model = teacher_model self.teacher_model.eval() # 确保teacher_model为评估模式 def forward(self, x): with torch.no_grad(): features = self.teacher_model.features(x) # 提取特征 outputs = self.teacher_model.classifier(features) # 计算分类概率 return features, outputs # 加载数据集和模型 train_loader, test_loader = load_data() teacher_model = load_teacher_model() # 定义原型提取器和新的分类器 proto_extractor = PrototypeExtractor(teacher_model) student_model = load_student_model() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(student_model.parameters(), lr=0.001, momentum=0.9) # 训练新的分类器 for epoch in range(10): running_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): optimizer.zero_grad() # 提取原型和分类概率 proto_features, proto_outputs = proto_extractor(inputs) student_outputs = student_model(proto_features) # 计算损失 loss = criterion(student_outputs, labels) + criterion(proto_outputs, student_outputs) # 反向传播和优化 loss.backward() optimizer.step() running_loss += loss.item() # 输出每轮训练的损失 print('[%d] loss: %.3f' % (epoch + 1, running_loss / len(train_loader))) ``` 在上述代码中,我们首先定义了一个 `PrototypeExtractor` 类,它接收一个已经训练好的分类器作为参数,并定义了一个 `forward` 方法用于提取原型和分类概率。在训练过程中,我们首先使用 `proto_extractor` 提取每个输入图像的原型和分类概率,然后使用 `student_model` 计算该图像的分类概率,并将原型和分类概率之间的损失添加到总损失中。最后使用反向传播和优化器更新 `student_model` 的参数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

computer_vision_chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值