Distillation论文总结(2)Distilling the Knowledge in a Neural Network

文章链接:arxiv

Distilling the Knowledge in a Neural Network

很多人喜欢训一堆模型把它们综合集成一下来获得更好的性能,然而这对于实际使用来说太笨重了。以前有人研究过将一个大模型的知识如何压缩到小模型中去,我们便研究出了一套不同的压缩方法:我们在MNIST数据集上将一个集成模型蒸馏到一个小模型中,得到了不错的结果。我们同时也提出了一个新的集成方法,,由总模型和细分模型组成。不像常用的集成方法,这些模型可以快速、并行训练。

引言

以昆虫作类比、幼虫与成虫阶段的目标是不同的,就像我们训练模型在训练和应用的阶段目的也是不同的。所以训练时会喜欢训一个很大、很复杂的模型,之后再用蒸馏得到一个小模型以使其满足实际应用。蒸馏的可行性已经被证明过了(见系列前一篇)。

许多人认为模型的知识是与其参数密不可分的,从而阻碍了他们对于蒸馏的认知。实际上我们应该把模型当作是一个输入输出的映射。大模型会给每一个类别一个概率,即便这一类的可能性非常小。

我们希望训练的目标应该尽可能与实际应用目的相一致。然而,训练是拟合已有数据而实际应用则要求模型在新数据上有好的泛化性能——毕竟泛化所需要的信息是未知的。当我们蒸馏时,我们可以使小模型具有和大模型一样的泛化性能,比如学习若干个集成模型的均值,这肯定比直接用小模型去训原始数据要好。

一个简单的方法是用大模型输出的soft targets训小模型,这比原标签提供了更多的信息以及数据之间更小的分歧。从而我们可以使用更少的数据、更大的学习速率。

然而有一个问题就是许多类别的概率会过于接近0,前文提到的方法是用softmax之前的logits来学习,本文提出一种更一般化的方法叫**“蒸馏”**,通过提高softmax层的温度来使输出的softmax target不那么极端,然后再用相同的温度去训小模型。用logits训其实就是温度T无穷大时候的特例。

蒸馏时使用的数据可以全是未标定数据,也可以直接用原始训练集。我们发现用原始训练集不错。同时,如果我们额外再加一项对GT的loss会更好。

蒸馏

蒸馏,我觉得就是把温度升高,从而使有用的信息可以蒸出来。==温度越高,概率分布就越松散。==训练时大模型与小模型用同样的温度,但是测试时小模型的温度还是1。
q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\sum_jexp(zj/T)} qi=jexp(zj/T)exp(zi/T)
我们发现加入对GT的学习能得到有效提升。其中一种做法是根据GT适当修改soft targets,另外一种更好的方法是把两种Loss加权求和。注意对GT求loss时温度为1,另外由于用soft targets求出的梯度与温度有关,所以这一项需要乘以T2,这样可以保证两项之间的比例关系不会随着温度变化而变化。

Logits就是一种特例

证明暂略

当温度较高时,蒸馏与Logits方法较相近,而当温度较低时蒸馏会不注意一些概率较小的类别信息。作者认为这是一种优势,因为

  1. logits是没有被loss限制住的,所以会有噪声
  2. 小模型本来就抓不住所有的信息,所以放掉一些也没关系。

在MNIST上做实验

写了一个双隐层1200ReLU(67误)和一个双隐层800ReLU(146误)的模型,后者用前者在T=20蒸馏可得结果(74误),相当好的改善了。

实验表明,如果小模型每层有300+个神经元的话,T>8之后就没有什么区别了,但是如果是30-个的话,2.5-4范围内的T会比其他温度好。

如果我们在蒸馏时去掉所有的“3”,也只有(206误),其中“3”误了133/1010,如果将对3的训练偏重加大3.5倍则(109误)“3”误14,即在没见过“3”的情况下达到98.6%的准确率。如果只用“7”和“8”会有47.3%的错误率,但是如果把它们的偏重降低7.6倍则会将错误率降到13.2%。

在语音识别上做实验

暂略

在大数据集上训练集成模型

训练集成模型可以利用并行计算的优势,而在test时的计算量问题可以用蒸馏解决。然而,另一个大问题是训练时的大量计算与资源需求

我们认为,训练一些特制分类器去针对那些易混淆类别是可以减少计算量的。最大的问题是这些分类器很容易过拟合,我们可以用soft targets来预防过拟合。

JFT数据集

JFT是一个有着15000类、100M张图片的数据集。在此之前Google给出的baseline是一个用异步随机梯度下降法训练了6个月的模型。模型训练过程中有两重并行。首先,在不同的核上跑着很多网络的复制处理不同的mini-batch,这些复制计算完之后会将梯度传给一个共用的参数管理器,然后参数管理器再把参数返回给各个模型;其次,同一个模型中,将不同的神经元也分布在了不同的核上。

特制模型

当数据类别非常多的时候,我们可以将大模型设计成一个总模型(在所有数据上训练)和很多特制模型(主要在一些易混淆类别上训练)。通过将所有其他类归为一类,这些特制模型的softmax可以变小很多。

为了防止过拟合,这些特制模型均为总模型参数初始化。然后训练的数据一半是其特别关注的类别,另外一半从其他类别中随机选择。训练后,我们可以修正这个有偏重的训练集,通过将其他类的logit提升特别类重采样比例的对数倍(???不懂)

类别分组

那么,怎么给特制分类器分组呢?我们会选择那些经常被总模型分错的类别。虽然可以用模糊矩阵,但是有一些更简单的方法,直接对总模型的输出预测的协方差矩阵进行聚类算法,如K-means,就可以得到类别的详细分类。

Inference过程

  1. 对输入,通过总模型找到n个最有可能的类别,称为集合k。不妨取n=1。
  2. 找到所有与k有非空交集的特制模型Ak,然后求一个最终概率分布 q \bm{q} q,使其最小化 K L ( p g , q ) + ∑ m ∈ A k ( K L ( p m , q ) ) KL(\bm{p^g},\bm{q})+\sum_{m \in A_k}(KL(\bm{p^m},\bm{q})) KL(pg,q)+mAk(KL(pm,q))其中 p g \bm{p^g} pg是总分类器的输出,而 p m \bm{p^m} pm是每个特制分类器的输出。

这个式子是没有闭式解的,我们一般会将其按 q = s o f t m a x ( z ) \bm{q}=\rm{softmax}(z) q=softmax(z)参数化,然后去优化z求解q。

结果

首先,这个方法训练很快,只需要几天;其次,所有的特制分类器是独立训练的。在test过程中最后有4.4%的提升。这个模型共有61个特制分类器,每个对应300类,这些类别还不是刚好划分的所以会有重叠。事实发现重叠会提升效果。

将soft target当作正则项

一个主要的思想是soft targets承载了很多有用的信息,这点我们通过用极少数据训练来说明。我们只用前面语音识别的3%的数据。如果直接用gt训练的话会严重过拟合——在达到44.5%之后快速下降,需要early stop。而用soft targets甚至不需要early stop它会自己收敛到57%,说明了soft targets可以恢复T模型的大量信息。从而说明了它是一个合格的正则项。

用soft target来防止特制分类器过拟合

由于所有其他类被归成了一类,而我们的训练集又关注特定类,所以训练集就少了,但是我们又不能增大训练集稀释特定类。

如果我们在训练特制分类器的时候增加一个其对于总分类器的输出的学习是不是可以保障其信息的保留?这个作者表示还需要进一步探索。

Relationship to Mixtures of Experts

暂略

结论

暂略

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值