轻量化——知识蒸馏(KD)

知识蒸馏是一种模型压缩技术,通过训练小型网络模仿大型或集成网络来提升效率。它通过调整softmax的温度参数T,使模型能关注到更多细节信息。教师网络生成软目标,然后蒸馏这些知识到学生网络,通过结合蒸馏损失和交叉熵损失进行优化。这种方法允许学生网络学习到教师网络的泛化能力,即使在直接利用logits的情况下也能有效进行知识转移。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

知识蒸馏

动机

为了提高网络的性能,采用多个模型训练之后再加权求平均得出输出值。但是这种方法去部署的时候却不容易。针对这个问题,采用的方法有两种:

  • 模型压缩
  • 训练轻量化模型

知识蒸馏就是采用的模型压缩的方法

思想

训练一个训练好的小网络去模仿一个预先训练好的大型网络或者集成网络

其中:知识的含义是模型的参数信息保留了模型学到的知识,学习如何从输入向量映射到输出向量

例如:教师网络经过softmax层输出的结果,通常是正确的分类概率比较大;而其他的类别的概率值几乎接近0。这种结果会忽略掉其它类别的概率中包含的有用信息,没有充分利用到教师网络强大的泛化能力。
在这里插入图片描述

例如:真实标签:3,最后模型最后预测的概率发现:4的概率小于8的概率。那么其实模型也可以从这里学习到,更接近8的形状比更接近4的形状是真实标签的概率要大。

即在原始的softmax的基础上添加一个参数T(温度)使得模型能够更加关注到细节信息
在这里插入图片描述

这个表可以看出,增加蒸馏温度,能够很好的捕捉到不同类别之间的有用信息
在这里插入图片描述

方法

神经网络预测的过程

  • 输入的图片送给卷积神经网络,提取特征

  • 拉伸卷积层,送入全连接层

  • 多层全连接层得到logits Zi

  • logits Zi经过softmax得到预测概率

蒸馏的过程:

  • 教师网络训练

首先利用数据训练一个层数更深,提取能力更强的教师网络,得到logits后,利用升温Tsoftmax得到预测类别的概率分布soft targets

  • 蒸馏

蒸馏教师网络知识到学生网络,构造distillation lossstudent loss,加权相加作为最后的损失函数
L = a Lsoft + b Lhard
注:soft target 产生梯度的大小按1/T^2缩放,因此再同时使用soft targetshard targets时,蒸馏损失乘以T^2

特殊蒸馏(直接利用logits)

直接利用softmax层的输入logits(而不是输出)作为soft targets。需要最小化的目标函数时教师网络和学生网络的logits之间的平方差

  • 交叉熵求导
    在这里插入图片描述
  • T足够大时
    在这里插入图片描述
    此处使用了等价无穷小
  • 假设所有的logits对每个样本都是零均值
    在这里插入图片描述
### 目标检测中的知识蒸馏结构化蒸馏 #### 知识蒸馏概述 知识蒸馏是一种用于压缩大型复杂模型的技术,通过让小型学生模型模仿大型教师模型的行为来提高效率。在目标检测领域,这种方法被广泛应用于减少计算资源消耗并加速推理过程。具体来说,知识蒸馏可以通过对齐教师模型和学生模型之间的预测 logits 来实现[^1]。 #### 定位蒸馏 (Localization Distillation) 定位蒸馏是针对密集型目标检测的一种特定形式的知识蒸馏技术。它不仅关注分类分数的学习,还强调边界框回归任务上的指导。GFL CVPR 2022 提出了基于定位蒸馏的方法,在密集对象检测场景下表现出显著效果。该方法的核心在于引入了一种新的优化策略——从 logits 到 cell 边界的方向调整,从而提升了学生的性能表现。 #### 结构化知识蒸馏 相比传统的全局特征映射方式,结构化知识蒸馏更注重局部区域内的语义信息传递。这种机制能够更好地捕捉图像中小尺度物体的关键特性,并将其有效地迁移到轻量化的学生网络中去[^2]。例如,在某些研究工作中提到 alpha 参数可以用来调节 LookAhead 差异的比例,默认设置为 .5 被认为是一个较为理想的选择;然而实际应用过程中仍需根据具体情况做适当微调以达到最佳平衡状态。 #### 实现方法代码示例 以下是利用 PyTorch 编写的一个简单版本的目标检测知识蒸馏框架: ```python import torch.nn as nn import torch.optim as optim class TeacherModel(nn.Module): def __init__(self): super(TeacherModel, self).__init__() # Define teacher model architecture here... class StudentModel(nn.Module): def __init__(self): super(StudentModel, self).__init__() # Define student model architecture here... def knowledge_distillation_loss(student_output, teacher_output, temperature=4.0): soft_student = nn.functional.softmax(student_output / temperature, dim=-1) soft_teacher = nn.functional.softmax(teacher_output / temperature, dim=-1) loss_fn = nn.KLDivLoss(reduction='batchmean') return loss_fn(torch.log(soft_student), soft_teacher) # Initialize models and optimizer teacher_model = TeacherModel() student_model = StudentModel() optimizer = optim.Adam(student_model.parameters(), lr=0.001) for data in dataloader: inputs, labels = data with torch.no_grad(): teacher_outputs = teacher_model(inputs) student_outputs = student_model(inputs) kd_loss = knowledge_distillation_loss(student_outputs, teacher_outputs) optimizer.zero_grad() kd_loss.backward() optimizer.step() ``` 上述代码片段展示了如何构建基本的知识蒸馏流程以及定义相应的损失函数 `knowledge_distillation_loss` 。这里采用了 KL 散度作为衡量两个分布之间距离的标准之一[^3]。 #### 参考文献扩展阅读建议 对于希望深入理解此主题的研究者而言,《Knowledge Distillation for Specific Object Detectors》提供了详尽理论基础及实践指南; 同时也可以参考《整理:4篇论文知识蒸馏引领高效模型新时代》,其中涵盖了更多关于不同应用场景下的创新思路和技术细节.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值