[DL]深度学习_知识蒸馏

【精读AI论文】知识蒸馏

目录

一、概述

二、知识蒸馏

1、引言

2、知识的表示与迁移

3、蒸馏温度T

4、知识蒸馏过程

5、实验结果 

6、知识蒸馏应用场景 

三、知识蒸馏和迁移学习的关系 

1、定义和目标

2、数据来源

3、目标任务

4、应用领域


一、概述

        知识蒸馏网络是一种神经网络模型,用于将复杂模型的知识转移到简化的模型,以提高模型的性能和效率。它借鉴了知识蒸馏方法的思想,通过将复杂模型的知识传递给简化模型来实现模型压缩和加速。

        知识蒸馏网络的基本原理是通过引入一个教师模型(通常是一个复杂的模型,例如深度神经网络)来指导一个学生模型(通常是一个简化的模型,例如浅层神经网络)的训练。教师模型在训练过程中可以提供额外的信息,例如类别概率分布、中间层特征表示等,以帮助学生模型更好地学习。将知识从教师网络迁移到学生网络中。

        在知识蒸馏网络中,通常使用两个损失函数来进行训练:硬性目标损失和软性目标损失。硬性目标损失是传统的监督学习损失,用于指导学生模型在训练数据上进行准确的预测,例如交叉熵损失。软性目标损失是根据教师模型的输出和学生模型的输出之间的差异来定义的,用于捕捉教师模型的知识。常见的软性目标损失包括均方误差损失和KL散度损失。

知识蒸馏网络的优点包括:

  • 模型压缩和加速:通过将复杂模型的知识转移到简化模型,可以减少模型的参数数量和计算复杂度,从而实现模型的压缩和加速。
  • 泛化能力改进:由于教师模型具有较强的泛化能力,通过知识蒸馏可以帮助学生模型更好地泛化到未见过的数据上。
  • 对抗对抗攻击:知识蒸馏可以提高模型的抗对抗攻击性能,使模型更加稳健和可靠。
  • 超参数调节:知识蒸馏可以帮助调节学生模型的超参数,例如学习率和正则化参数等,以提高模型的性能和鲁棒性。

        然而,知识蒸馏网络也存在一些挑战和限制,例如如何选择合适的教师模型、如何设计有效的损失函数以及如何平衡硬性目标损失和软性目标损失等。

二、知识蒸馏

1、引言

        “在机器学习领域中,往往训练和部署用同一套模型,训练阶段和部署阶段的目标是不同的,训练阶段的目标是从数据集中学习到海量的规律和特征,而部署阶段的目标之一是足够的快,足够轻量化,占用资源尽量的少,拥有较高效率。”

2、知识的表示与迁移

         在针对分类网络中,若输入数据为一幅马的图像,在训练阶段,使用Hard targets进行训练,只告诉模型该输入图片为一匹马,则此时马的可能性为1,其余标签的可能性为0。

        但是在实际应用场景下,Hard targets具有不合理的地方。使用Hard targets相当于告诉网络该图像中包含的是一匹马,不是驴也不是汽车,此时不是驴和不是汽车的概率是相等的,均为0,但是实际应用场景中明显输入图像马和驴相较于汽车来说是有很多相似性的,应该是更像驴,而更不像汽车。

        但是若是在预测阶段,将马的图像输入到一个已经训练好的分类模型中,模型会给出Soft targets样式的概率结果,预测为马的概率最大,但是预测为驴的概率也比汽车的概率要大。

        在蒸馏框架中,训练教师网络时可以使用Hard targets,教师网络在预测时候输出的是Soft targets,这样就可以传递出更多的信息,利用教师网络输出的Soft targets来训练学生网络。

Hard targets

        硬目标(Hard Targets)是指传统的监督学习中使用的目标损失函数。它基于真实标签与模型预测之间的差异来度量模型的性能。常见的硬目标损失函数包括均方误差损失(Mean Squared Error,MSE)和交叉熵损失(Cross Entropy Loss)等。硬目标损失函数用于指导学生模型在训练数据上进行准确的预测。学生模型努力通过最小化硬目标损失函数来逼近教师模型的性能。

Soft targets

        软目标(Soft Targets)是在知识蒸馏网络中引入的一种额外的目标损失函数。它通过对比教师模型和学生模型之间的输出概率分布差异来度量模型之间的相似性。常用的软目标损失函数包括均方误差损失(Mean Squared Error,MSE)和KL散度损失(Kullback-Leibler Divergence Loss)等。软目标损失函数帮助学生模型更好地学习教师模型的知识,使得学生模型可以获得更多的信息或细节,比如模型在不确定样本上的输出概率分布。

优势特点

  • 通过使用软目标损失函数,学生模型可以从教师模型的输出中获得更多的信息,从而更好地学习复杂模型的知识。
  • 软目标为学生模型提供了一种辅助训练的方式,使得学生模型能够接近或超越教师模型的性能。
  • 软目标还可以帮助学生模型在未见过的数据上更好地泛化,提高模型的鲁棒性。
  • 在知识蒸馏网络中,通常使用硬目标和软目标损失函数的加权和来平衡两者的作用。权重的选择可以根据具体的任务和需求进行调整,以达到最佳的性能和效果。
  • 硬目标损失函数用于指导学生模型在训练数据上的准确预测,而软目标损失函数通过比较教师模型和学生模型之间的输出概率分布差异来传递教师模型的知识。
  • 这两种目标损失函数的结合可以帮助学生模型获得更好的性能和泛化能力。

3、蒸馏温度T

        使用教师网络预测出的Soft targets作为训练学生网络的标签,但是还需要细化Soft targets类别间的差异,利于学生网络解耦出非正确类别间的信息,引入蒸馏温度T,这个蒸馏温度越高,类间差异越细化。

        蒸馏温度通常被表示为一个正数,通常为1或更大。它用于调整教师模型和学生模型之间的概率分布差异的程度。较高的蒸馏温度会使教师模型的概率分布更加平滑,从而使得学生模型在训练过程中更容易学习到教师模型的知识。

        通过增加蒸馏温度,软目标损失函数中的目标分布变得更加模糊,学生模型将更加关注教师模型的大概率预测。这有助于学生模型更好地学习到教师模型的知识,包括模型在不确定样本上的输出概率分布。

        在原先的Softmax操作中,将公式中e的指数次方除以T,若T=1,则等同于原始Softmax函数。若T值大于1,则原始较为显著的差异概率分布就变得较为平滑。T越高,非正确类别的概率的相对大小信息暴露更加明显。较为平滑的Soft targets包含更多非正确类别间的差异信息,使学生网络更好地学习到所有类别间的特征。

4、知识蒸馏过程

  1. 有一个已经训练好的教师网络;
  2. 将数据输入给教师网络,计算出每个数据在温度T时刻的Softmax;
  3. 将数据输入给学生网络,该学生网络还未开始训练或还没训练完成。也预测出每个数据在温度T时刻的Softmax;
  4. 计算在温度T下,教师网络和学生网络分别计算出的Softmax之间的损失,目的是让教师网络和学生网络之间的损失函数越小越好,即学上网络在模仿教师网络的Soft预测结果的过程;
  5. 学生网络还计算一个T=1的,即普通的Softmax预测,与Hard label(Ground Truth)计算损失函数,目的是让Hard预测与真实值更接近;
  6. 学生网络既要兼顾温度为T时候的预测结果要与教师网络尽可能接近,也要兼顾T=1时,预测结果要与标准答案更接近,总损失函数由Distillation loss(Soft loss)和Student loss(Hard loss)加权求和。

notes: 

  • Distillation loss相当于有一个老师在手把手教,告诉这是一个马,更像驴,更不像车,交学生驴和车与马之间的相似性;
  • Hard loss相当于有一个课本,课本上有一幅插图是马,插图下边写着这就是马,不是别的东西。
  • 整个蒸馏网络的目标就是微调学生网络的参数权重,使得最终Total loss最小化。

5、实验结果 

少样本/零样本效果

        在论文中,在MNIST手写识别数据集上进行了预实验,发现知识蒸馏架构有附带的效果:

        在训练学生网络时,将数据集中类别为3的所有数据剔除,但是训练教师网络时利用完整数据集进行训练。在训练结束后,会发现教师网络将知识迁移给学生网络之后,学生网络在遇到没有训练过的3这个类别的数据时,也会依靠教师网络传递给的知识正确识别出来。

        即知识蒸馏框架可以实现少样本甚至零样本的训练。

防止过拟合

        当用非常少的数据集(3%)去训练一个网络时:

  • 使用Hard targets训练时,在训练集上准确率高,但是在测试集上准确率低,造成过拟合现象;
  • 使用Soft targets训练时,仍然保持了一定的效果。

6、知识蒸馏应用场景 

  • 模型压缩
  • 优化训练,防止过拟合(潜在的正则化)
  • 无限大、无监督数据集的数据挖掘
  • 少样本、零样本学习 

三、知识蒸馏和迁移学习的关系 

1、定义和目标

  • 知识蒸馏:知识蒸馏是一种将一个复杂模型中的知识转移给一个轻量级模型的方法。通过这种方式,轻量级模型可以学习到与原模型相似的知识和能力,但具有更高的效率和更小的模型大小。
  • 迁移学习:迁移学习是通过将已经学习到的知识或模型应用到新的任务或领域中,以加速学习过程或提高学习性能的方法。迁移学习关注的是如何将一个任务的知识迁移到另一个任务上,以提高目标任务的性能。

2、数据来源

  • 知识蒸馏:知识蒸馏通常涉及到两个模型:一个是教师模型(复杂模型),另一个是学生模型(轻量级模型)。教师模型通过训练数据得到,学生模型通常是一个较简单的模型。
  • 迁移学习:迁移学习可以基于已有数据来训练模型,也可以基于预训练模型来进行模型的迁移。预训练模型可以是在大规模数据上通过自监督学习或者其他任务训练得到。

3、目标任务

  • 知识蒸馏:知识蒸馏的目标是通过将复杂模型中的知识迁移到轻量级模型上,使得轻量级模型能够达到与复杂模型相似的性能。通常,知识蒸馏的目标是提高轻量级模型的泛化性能和效率。
  • 迁移学习:迁移学习的目标是通过迁移已有任务的知识或模型来提升目标任务的性能。迁移学习的目标是加速学习过程、提高模型的泛化能力或解决数据稀缺的问题。

4、应用领域

  • 知识蒸馏:知识蒸馏在模型压缩和加速方面有广泛的应用。例如,在移动设备上部署机器学习模型时,可以使用知识蒸馏技术将复杂模型转化为轻量级模型,从而在保持性能的同时减少计算和存储资源的开销。
  • 迁移学习:迁移学习在很多领域都有应用,尤其是当目标任务的数据较少或者在新的领域中,可以通过迁移已有的知识或模型来提升性能。例如,在自然语言处理领域,可以使用预训练的语言模型来迁移学习到下游任务上。
  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

IAz-

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值