Mobile-SAM使用的知识蒸馏方法论文讲解Distilling the Knowledge in a Neural Network

本文介绍了知识蒸馏的概念,即通过将大型神经网络的知识转移给小型网络。重点讨论了如何通过调整softmax中的温度来优化知识转移,以及全耦合、半耦合和解耦蒸馏三种不同的方法。作者还探讨了软目标和硬目标的优缺点,并提供了实例和温度选择的建议。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、总结

1. 简介

发表时间:2015年3月9日

论文:[1503.02531] Distilling the Knowledge in a Neural Network (arxiv.org)icon-default.png?t=N7T8https://arxiv.org/abs/1503.02531

观看视频地址:爆肝66小时!全网最细知识蒸馏论文精讲和代码逐行讲解_哔哩哔哩_bilibiliicon-default.png?t=N7T8https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click本文是观看上述视频做的笔记,感兴趣的同学也可以去看下视频讲解哈,时间不长,半个小时左右,后半部分还包括代码讲解,想快速了解的可以参考下本篇笔记,本片笔记只包含理论知识

2. 中心思想

作者想用不同的方法将大模型压缩到小模型当中

二、知识蒸馏的定义

1. 基础

知识:通常认为,知识是模型学习到的参数 (比如卷积的权重)
蒸馏:将知识从大模型(教师网络)转移到更适合部署的小模型(学生网络)
将知识从大模型迁移到小模型,模型的结构都不同,那这些 (知识)参数怎么会迁移成功?
        例如教师网络在识别一张宝马车图片时,可能会误认为垃圾车,很小的概率误认为胡萝卜,这就隐含了概率相对大小里面的一些隐含知识,所以迁移知识是可行的。

2. 模型输出

一般模型输出包括以下三种:
logits:全连接层的输出
hard targets:logits 通过 one-hot 编码实现
soft targets:logits 通过 softmax 处理后,得到的结果

3. Soft targets的优缺点

soft targets相比于hard targets的优点:
        例如下图是一个图片分类任务,要把输入的图片识别成0~9之间的一个数字。首先纵向对比,左子图中Hard Target就只有一个概率,而Soft Target有很多其他的概率,而且其他概率中3是最高的,说明图片中的2也有点像3;同理,右子图中Hard Target同样得不出什么信息,只能得出这张图片是个2,而Soft Target还可以发现这张图片有点像7。横向对比,不同的图片,Hard Target一模一样,而Soft Target不一样。结论:Soft Target蕴含的信息相比于Hard Target更多一点,类似于前边宝马车的例子。

soft targets的缺点:
论文中的团队在做mnist任务时,也就是将一张图片识别0~9之间的数字,将2识别成3的概率比识别成7的概率大一点,但两个概率数都特别小,因此对损失函数也就是交叉熵的影响特别小,因此本文只用logits进行处理,在softmax里面引入温度。

三、知识蒸馏的温度

如下图所示,展示了本文的创新点,在公式中加入了温度。

如下图所示,在设置温度为不同值情况下,预测四种类别的概率值分布。
        从图中可以看出,当温度设置为1时, 可以看到蓝色折线比较生硬,虽然说bee的预测概率很明显,但其他类别预测的概率很低,差距较小;当温度设置为10时,bee的预测概率值降低,其他类也更加相近了,也更容易比较出bee跟什么类别更像,跟什么类别更不像;当温度设置为100时,红色折线几乎持平为平均分布;当温度设置为无穷大时,直接使用logits=直接使用softmax且引入无穷大温度,是一般知识蒸馏的一个特殊形式。
        因此温度的设置不能取大也不能取小,凭借经验,一般设置在1~20之间,当学生模型比教师模型小很多时,较低的温度效果更好。

 四、知识蒸馏的过程

1. 整体流程

如下图所示,每个颜色路线对应每个过程,具体过程如下:
(1)把数据喂到教师网络中训练,通过升温之后的softmax,得到soft targets1
(2)把数据喂到学生网络中训练,通过同温之后的softmax,得到soft targets2
(3)通过soft tatgets1和soft targets2计算得到蒸馏损失distillation loss
(4)把数据喂到学生网络中训练,通过正常的softmax,得到soft targets3
(5)通过soft tatgets3和正确标签correct label计算得到学生损失student loss

2. 两种损失

2.1 蒸馏损失

(1)输入:相同温度下,学生模型和教师模型的soft targets
(2)常用:KL散度处理
(3)作用:让学生网络的类别输出预测分布尽可能拟合教师网络输出预测分布,也就是让学生去学习老师的一些行为。

2.2 学生损失

(1)输入:T=1时,学生模型的soft targets和正确标签
(2)常用:交叉熵损失,因为图片分类常用的损失是交叉熵
(3)作用:减少教师网络中的错误信息被蒸留到学生网络中

2.3 独立的两种损失如何建立联系

         如下公式所示,通过加权平均的方法将两种损失相加。
        一般a值取0.5,当第二个权重值a小一点时,效果会更好。
        乘以T的平方是因为soft targets会产生梯度大小按照1/T的平方进行缩放,蒸馏损失中有两个soft targets,学生损失只有一个soft targets,为了保证这两者的贡献相同,因此在蒸馏损失前乘以一个T的平方,相当于给他中和掉了

五、知识蒸馏方法 

        在知识蒸馏领域,全耦合蒸馏、半耦合蒸馏和解耦蒸馏是三种不同的方法,它们在教师模型和学生模型的知识传递过程中采取不同的策略。这三种蒸馏策略各有优势和应用场景,选择哪一种策略取决于具体任务、模型的复杂性、计算资源的可用性,以及优化的目标。全耦合蒸馏提供了一种全面学习教师模型知识的方式;半耦合蒸馏在减轻计算负担和简化学习过程中找到了平衡;而解耦蒸馏则提供了最高的灵活性和细粒度控制,适用于结构或功能上有较大差异的教师和学生模型之间的知识传递。下面是对这三种方法的简要概述:

1. 全耦合蒸馏

        全耦合蒸馏指的是在知识蒸馏过程中同时优化学生模型的所有参数,包括对教师模型输出的学习和教师模型的中间层特征表示的学习。在这种方法中,学生模型尝试直接模仿教师模型的行为,包括其最终输出和内部特征表示。这种蒸馏方法可以充分利用教师模型的知识,但可能会因为教师模型和学生模型之间的结构差异而难以优化。

2. 半耦合蒸馏

        半耦合蒸馏相较于全耦合蒸馏,采取了一种更为灵活的方法。它通常只关注于学生模型的部分参数优化,可能是最终输出的直接模仿,也可能是对教师模型某些中间层特征的模仿,而不是全面模仿教师模型的每一层。这种方法的一个典型例子是冻结学生模型中的一部分参数(例如,前几层),仅对其余参数进行优化。半耦合蒸馏可以简化学习过程,减少计算负担,同时仍然能够从教师模型中获得有用的知识。

3. 解耦蒸馏

        解耦蒸馏采取了一种更为分离的方法,它将教师模型的知识传递过程分成独立的阶段或部分。首先,可能专注于从教师模型中提取关键信息或特征表示,然后在后续阶段单独优化学生模型的对应部分。这种方法允许更细粒度的控制和灵活性,可以分别针对学生模型的不同部分进行优化。解耦蒸馏可以帮助克服直接模仿教师模型时可能出现的结构或功能不匹配问题,使得学生模型能够更有效地利用教师模型的知识。

### 实现UNet模型中的知识蒸馏 #### 背景介绍 知识蒸馏是一种有效的模型压缩技术,允许小型学生网络从大型教师网络中学习。对于像UNet这样的复杂架构,在医学影像分析等领域尤为重要。通过引入知识蒸馏机制,可以在保持较高精度的同时减少计算资源消耗。 #### 方法概述 为了在UNet模型中实施知识蒸馏,主要步骤如下: - 构建一个更大更深的UNet作为教师模型; - 训练该教师模型直到收敛; - 使用相同的数据集训练较小的学生版UNet,并利用来自教师模型软标签指导损失函数的设计; - 设计合适的损失组合方式以平衡原始监督信号与教师传递的信息。 #### 示例代码实现 下面给出一段Python代码片段展示了如何构建这种师生结构以及相应的训练流程: ```python import torch.nn as nn from torchvision import models class TeacherUNet(nn.Module): def __init__(self, num_classes=1): super(TeacherUNet, self).__init__() # 定义更深层次或宽度更大的U-Net架构 pass def forward(self, x): output = ... return output class StudentUNet(nn.Module): def __init__(self, num_classes=1): super(StudentUNet, self).__init__() # 定义简化版本的U-Net架构 pass def forward(self, x): output = ... return output def distillation_loss(student_output, teacher_output, target, T=20.0, alpha=0.7): """定义混合了交叉熵和KL散度的距离损失""" loss_fn_kd = nn.KLDivLoss(reduction='batchmean') loss_fn_ce = nn.CrossEntropyLoss() kd_loss = loss_fn_kd( F.log_softmax(student_output / T, dim=-1), F.softmax(teacher_output / T, dim=-1)) * (alpha * T * T) ce_loss = loss_fn_ce(student_output, target) * (1. - alpha) total_loss = kd_loss + ce_loss return total_loss if __name__ == '__main__': device = 'cuda' if torch.cuda.is_available() else 'cpu' # 初始化教师和学生的UNet实例 teacher_model = TeacherUNet().to(device) student_model = StudentUNet().to(device) optimizer_student = ... # 配置优化器 scheduler_student = ... # 可选配置学习率调度策略 epochs = 100 for epoch in range(epochs): train_loader = ... # 加载数据 running_loss = [] for images, labels in train_loader: inputs_teacher, targets = images.to(device), labels.to(device) with torch.no_grad(): outputs_teacher = teacher_model(inputs_teacher) outputs_student = student_model(images.to(device)) loss = distillation_loss(outputs_student, outputs_teacher, targets) optimizer_student.zero_grad() loss.backward() optimizer_student.step() running_loss.append(loss.item()) ``` 此段代码仅提供了一个框架性的描述,具体细节如`TeacherUNet`类内部的具体层设计、激活函数的选择等都需要根据实际应用场景进一步完善[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值