知识蒸馏-【初识】

刚开始了解知识蒸馏这个领域,做个知识拓展学习:
参考https://www.bilibili.com/video/BV1gS4y1k7vj/?spm_id_from=333.788&vd_source=b978c891c9fe5aca965ea7a3b9a8063d
论文下载:Distilling the Knowledge in a Neural Network
论文解读
知识蒸馏即通过一个精度比较高即知识储备比较多的教师模型(当然可以多个教师),将知识迁移给学生模型,该学生模型的结构会较为轻量且高效。
在图像识别中,采用教师模型作分类,通常直接输出1 or 0;
如在识别马、驴、汽车的过程中,hard targets是在神经网络预测时,该图是马的情况,直接会输出马,即预测概率为马的值为1,对长相略像马的驴也直接会直接输出0。这种方式就是硬标签,这样的方式于我们而言是不利的。
我们理想的是一种软标签:
当图片为马时,我们预测马可能会出现一个概率为0.7,驴与马比较相似,所以概率也会略大些为0.25,但是汽车与马一点也不像这样概率为0.05。这就是软标签的形式。该方式的话会引导神经网络(student)去学习一种概率,而不是单纯的告诉他这不是汽车,这是一个马。
在这里插入图片描述
可以看下面的这个图,加深下理解:
在这里插入图片描述
hard target:熵低,信息量小;
soft target:熵高,信息量大。
但是呢,在实际使用时,由于交叉熵的代价函数的影响小,概率非常接近与0
这句话的意思是:他会让一些不太像的比如汽车这样的赋值给一个非常低的置信,所以这样就会加剧“贫富差距”,–长得更像的就会logit更高,长得一点也不像的logit就会更低了。所以这不是我们所期望的。
那么就有一些学者展开了研究:
*Caruana和他的合作者通过使用logit(最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的logit和小模型产生的logit之间的平方差。

理解这段话:他们的做法是采用线性分类分数直接拟合,也就是直接用student模型直接拟合teacher模型的输出分数。是采用一种回归的思路。

我们更通用的解决方案称为“蒸馏”,即提高最终softmax的温度,直到繁琐的模型产生合适的soft targets。然后我们在训练小模型时使用相同的高温来匹配这些软目标。我们稍后将说明,匹配这个繁琐模型的对数实际上是蒸馏的一种特殊情况。*
1)当teacher网络已经训练好后,student模型也可以使用相同的数据集,或者用一个更大的数据集,甚至该数据集中可以有很多未标注的数据,采用自监督的方式喂给teacher模型,得到概率值给student学习。本文中发现用原来那个相同的数据集,其结果已经是不错了。
总loss=soft loss + hard loss(soft:老师的言传身教(0-1的概率)以及学生学习到的soft+hard:课本和习题(非0即1))
在这里插入图片描述

首先输入数据,teacher先预测去估计一个soft target,然后student会拟合soft target,得到soft loss,注意此时是T=t,学生和教师是在同一个温度下进行的整理操作;
同时学生也要从课本学习到知识,此时T=1,该过程是不经过蒸馏的,得到hard loss。
接下来举例看蒸馏的过程:
首先student过程
(1)前向预测,得到logit,然后采用softmax,提供一个概率:
在这里插入图片描述

(2)此时未蒸馏,将预估结果与hard label结合,得到hard loss的结果在这里插入图片描述
(3)T=3时,此时已经蒸馏了,设置T=3,此时指数部分都要除以T,此时T=3,所以指数都是变成了除以3。该方法的目的就是为了避免贫富差距过大,比如前面那个T=1时的学生网络结果7.32×10的-7次方,接近于0了,但是该方法后该结果是0.0058。

在这里插入图片描述
这是T不同时的预估值,可以看到在T越大时,其相对结果越接近(T=100,紫色那条)。
在这里插入图片描述
teacher过程

(1)该过程与学生网络相仿,也是先get logit,然后找到T=1时的结果以及升高温度后的概率结果。在这里插入图片描述
(2)求soft loss,下面有公式。
在这里插入图片描述
整体过程可以看视频讲解。
在这里插入图片描述
下图过程与上面过程一致。首先教师网络预估出一个soft target,然后学生模型得到蒸馏后的soft loss以及没有蒸馏的hard loss,得到一个总的loss训练学生网络。T越大,则包含的知识越多。但是并不是温度越高越好的。太大可能会带来噪声,太小学习到的知识不够。
在这里插入图片描述
一些特例
在这里插入图片描述
由教师网络到学生网络是一个知识传递的过程。如在学习平移时,学生网络可能没有学习过,但是教师网络学习过,那么学生网络可能也会具备判别平移的能力(零样本学习)。

论文内容
大模型-》transfer 小模型时,是将大模型产生的类概率作为小模型的soft label,在transfer阶段可以使用相同的训练集或单独的转移数据集。
当teacher模型为较为简单的模型的集合时,也就是多个teacher教一个student时。可以使用各自预测的算术或平均值作为soft label。当soft label的熵很高时,在每个训练案例中提供的信息会比硬目标提供的多得多,在训练间的梯度方差也会小得多。因此小模型通常可以用比原始大模型训练时使用更少的数据进行训练,同时可以赋值给更高的学习率。

大模型用于像MNIST这样的任务时,通常具备更高的置信度,在transfer过程中,由于交叉熵的代价函数的影响小,概率非常接近与0。
caruana与其合作者,采用logit(softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这一问题。最小化大模型产生的logit和小模型产生的logit之间的平方差。

本文采用蒸馏方法,即提高softmax的温度,直到大模型产生合适的soft label集。在训练小模型时使用相同的温度匹配这些soft label。

2.蒸馏

公式中,通常经过logit后,会采用softmax产生每个类别的概率结果,将每个类别得到的logit结果进行比较,得到概率qi加粗样式,zi是logit分数。在这里插入图片描述
T=1时就是softmax
在这里插入图片描述
首先教师网络通过蒸馏方法得到一个soft target,学生网络得到两个结果(1)采用相同的温度蒸馏,跟教师网络进行拟合,得到soft loss;(2)不经过蒸馏,和hard label进行拟合,得到hard loss。最终的total loss是由这两部分组成的。
以上为训练阶段。在预测阶段时,直接通过student模型得到一个预测结果。T越大,teacher中的soft target越软,所提供的知识越多。

在最简单的蒸馏形式中,知识被转移到蒸馏模型上,方法是在一个转移集中训练它,并对转移集中的每个情况使用软目标分布,该转移集中由使用softmax中具有高温的繁琐模型生成。训练蒸馏模型时使用相同的高温,但训练后它使用的温度为1(预测阶段 )

当所有或部分传输集都知道正确的标签时,通过训练蒸馏模型生成正确的标签,可以显著改进该方法。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是软目标的交叉熵,这个交叉熵是使用蒸馏模型的softmax中的相同高温计算的,就像从麻烦的模型生成软目标时使用的一样。第二个目标函数是带有正确标签的交叉熵。这是使用蒸馏模型的softmax中完全相同的logits计算的,但温度为1。我们发现,在第二个目标函数上使用一个可调节的较低权重通常可以获得最佳的结果。由于软目标产生的梯度大小为1/ t2,因此在使用硬目标和软目标时,重要的是将其乘以t2。这确保了在使用元参数进行实验时,如果用于蒸馏的温度发生改变,则硬目标和软目标的相对贡献大致保持不变。

2.1 直接让学生网络拟合教师网络输出的logit是知识蒸馏的一个特例

在这里插入图片描述
当蒸馏温度足够大事,知识蒸馏就等价于最小化均方误差。在较低的温度下,蒸馏对比平均值负得多的logit的匹配关注要少得多。这是潜在的优势,因为这些logit几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常消极的logit可能会传递关于繁琐模型所获得的知识的有用信息。这些效应中哪一种占主导地位是一个经验问题。我们表明,当蒸馏模型太小,无法捕获繁琐模型中的所有知识时,中间温度工作得最好,这强烈表明忽略大的负对数是有帮助的。

在这里插入图片描述

3 MNIST为例子

采用教师网络有60000个训练案例,两个隐含层,每个有1200个relu激活函数神经元,并赋予了Dropout正则化。
学生网络有两个隐含层,每个有800个relu激活函数神经元,没有Dropout。
未进行知识蒸馏:学生犯了146个错误;
使用T=20时:学生犯了74个错误。
当student有300个甚至更多的神经元时,T=8及以上的温度,结果相似。当student有30个事,2.5-4的温度效果更佳。
零样本学习:教师模型学习过平移的知识,学生模型没有学过,经过蒸馏后,学生也会有平移的能力。
同时尝试了删除样本数字为3的样本示例。对student而言,数字3从未见过。经过蒸馏后发现有206个错误,其中133个是在数字3上犯错(3共有1010个图片)。这是因为3这个数字的偏置太小了,而数字3,student从未见过,我们尝试将偏置改成3.5,现在有109个错误,其中14个是在3上犯错误。
蒸馏可以学习小样本学习零样本学习等等。
transfer中不包含哪类,学生网络中哪类的偏置低。

4 应用于语音识别的案例

比较直接用数据集训练一个小网络,以及直接用知识蒸馏蒸馏出一个小网络,发现采用知识蒸馏的方法效果更佳。

目前语音识别领域中的这个ASR系统是采用DNN模型,将音频中的声波提取时间上下文的关系,映射到隐马尔可夫的离散状态。
语音转文字的过程应该既符合声音的特征,又符合语言文本的特征。
在这里插入图片描述
该过程就是输入声学的特征,来找隐马尔可夫的状态。
在这里插入图片描述
实验中验证了采用baseline、10个模型集成,以及知识蒸馏后的学生模型的预测结果,可以看到采用10个模型作集成学习时效果最好。这10个模型的特点是:好而不同,这10个模型都采用了随机初始化操作,以保证每个模型都是各不相同的,同时在同一个训练集上训练,得到模型。
但是从表中也可以看到,单个的经过知识蒸馏后的学生模型其精度比baseline这个大模型的效果还要好,同时比集成模型在精度上略低,WER持平。

5 Training ensembles of specialists on very big datasets

训练模型的集合是利用并行计算的一种非常简单的方法,通常集成学习在测试时需要太多的计算,这一点可以通过使用蒸馏来解决。然而,对于集成还有另一个缺点:如果单个模型是大型神经网络,数据集非常大,训练时所需的计算量过多,他就不容易并行化了。

每一个模型都训练起来十分庞大,那应该怎样并行?
想法:每一个模型都关注不同的点,成为不同的专家模型,这样每个专家只需要在一小部分的数据中进行训练即可。也就是建立专家模型,将多个不用领域的专家模型集合,但是这种专注于细粒度区分的专家模型很容易就过拟合了,为此我们描述了如何通过使用软目标来防止这种过拟合。

5.1 The JFT dataset

谷歌内部数据集,未公开。
JFT是一个内部的谷歌数据集,它有1亿张带有15,000个标签的标记图像。当我们做这项工作时,谷歌的JFT基线模型是一个深度卷积神经网络[7],它已经在海量算力资源中使用异步随机梯度下降训练了大约6个月。本训练使用了两种类型的并行[2]。首先,有许多神经网络的副本运行在不同的核集上,并处理来自训练集的不同小批次。每个副本计算当前小批处理上的平均梯度,并将该梯度发送到一个分片参数服务器,该服务器发回参数的新值。这些新值反映了自上次向副本发送参数以来参数服务器接收到的所有梯度。其次,每个副本通过在每个核上放置不同的神经元子集而分布在多个核上。集成训练是第三种可包装的并行

在这里插入图片描述
FitNet是feature-based的知识蒸馏方法。分为两个阶段:
1.以student(比教师网络:深而薄的网络)学习teacher(深而宽的网络),需要对齐维度(所以相较于logit-based方法有一定的时间消耗)。
2.采用hinton中的KD方法蒸馏。
在这里插入图片描述

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值