知识蒸馏(Knowledge Distillation)

本文探讨知识蒸馏在深度学习领域的应用,包括模型压缩、迁移学习和多教师信息融合。通过知识蒸馏,学生网络可以从教师网络中学习到软目标、特征表示和解决方案流程等知识,提高推理效率。介绍了多种知识蒸馏方法,如Distillation、Attention Transfer、Ensemble学习等,并探讨了对抗样本支持的决策边界知识迁移、自监督学习与知识蒸馏的结合等前沿研究。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文主要罗列与知识蒸馏相关的一些算法与应用。但首先需要明确的是,教师网络或给定的预训练模型中包含哪些可迁移的知识?基于常见的深度学习任务,可迁移知识列举为:

  • 中间层特征:浅层特征注重纹理细节,深层特征注重抽象语义;
  • 任务相关知识:如分类概率分布,目标检测涉及的实例语义、位置回归信息等;
  • 表征相关知识:强调特征表征能力的迁移,相对通用、任务无关(Task-agnostic);表征间相关性,如相似度、Relation等;

另外,知识蒸馏的应用主要有哪些?粗略概况,可包含模型压缩、迁移学习与多教师信息融合等。

1、Distilling the Knowledge in a Neural Network

Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(Teacher network:复杂、但预测精度优越)相关的软目标(Soft-target)作为Total loss的一部分,以诱导学生网络(Student network:精简、低复杂度,更适合推理部署)的训练,实现知识迁移(Knowledge transfer)。

如上图所示,教师网络(左侧)的预测输出除以温度参数(Temperature)之后、再做Softmax计算,可以获得软化的概率分布(软目标或软标签),数值介于0~1之间,取值分布较为缓和。Temperature数值越大,分布越缓和;而Temperature数值减小,容易放大错误分类的概率,引入不必要的噪声。针对较困难的分类或检测任务,Temperature通常取1,确保教师网络中正确预测的贡献。硬目标则是样本的真实标注,可以用One-hot矢量表示。Total loss设计为软目标与硬目标所对应的交叉熵的加权平均(表示为KD loss与CE loss),其中软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的预测精度通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。

教师网络与学生网络也可以联合训练,此时教师网络的暗知识及学习方式都会影响学生网络的学习,具体如下(式中三项分别为教师网络Softmax输出的交叉熵loss、学生网络Softmax输出的交叉熵loss、以及教师网络数值输出与学生网络Softmax输出的交叉熵loss):

联合训练的Paper地址:https://arxiv.org/abs/1711.05852

2、Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions

GitHub地址:https://github.com/peterliht/knowledge-distillation-pytorch

这篇文章将Total loss重新定义如下:

Total loss的PyTorch代码如下,引入了精简网络输出与教师网络输出的KL散度,并在诱导训练期间,先将Teacher network的预测输出缓存到CPU内存中,可以减轻GPU显存的Overhead:

def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha
    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    alpha = params.alpha
    T = params.temperature
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
                             F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

3、Ensemble of Multiple Teachers

Paper地址: Efficient Knowledge Distillation from an Ensemble of Teachers | Request PDF

第一种算法:多个教师网络输出的Soft label按加权组合,构成统一的Soft label,然后指导学生网络的训练:

第二种算法:由于加权平均方式会弱化、平滑多个教师网络的预测结果,因此可以随机选择某个教师网络的Soft label作为Guidance:

第三种算法:同样地,为避免加权平均带来的平滑效果,首先采用教师网络输出的Soft label重新标注样本、增广数据、再用于模型训练,该方法能够让模型学会从更多视角观察同一样本数据的不同功能:

4、Hint-based Knowledge Transfer

Paper地址:htt

评论 34
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值