Distilling the Knowledge in a Neural Network

Distilling the Knowledge in a Neural Network

Abstract

提高几乎任何机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是,使用一整套模型进行预测是很麻烦的,而且计算成本可能太高,不允许部署到大量用户,特别是在单个模型是大型神经网络的情况下。Caruana和他的合作者[1]已经表明,将集成中的知识压缩到一个更易于部署的单一模型中是可能的,并且我们使用不同的压缩技术进一步开发了这种方法。我们在MNIST上取得了一些令人惊讶的结果,并且我们表明,通过将模型集合中的知识提取到单个模型中,可以显著地改进大量使用的商业系统的声学模型。我们还介绍了一种由一个或多个完整模型和许多专家模型组成的新型集成,这些模型学习如何区分完整模型混淆的细粒度类。与专家的混合不同,这些专家模型可以快速并行地进行训练

1 Introduction

许多昆虫都有一种幼虫形态,可以从环境中提取能量和营养,而另一种完全不同的成虫形态则可以满足不同的旅行和繁殖需求。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别这样的任务,训练必须从非常大、高度冗余的数据集中提取结构,但它不需要实时操作,而且它可以使用大量的计算量。然而,部署到大量用户对延迟和计算资源有更严格的要求。与昆虫的类比表明,如果能够更容易地从数据中提取结构,我们应该愿意训练非常繁琐的模型。笨重的模型可以是一个单独训练的模型集合,也可以是一个非常强大的正则化器训练的非常大的模型,比如dropout[9]。一旦累赘模型得到训练,我们就可以使用另一种训练,我们称之为“蒸馏”,将知识从累赘模型转移到更适合部署的小模型。Rich Caruana和他的合作者已经开创了这种策略的一个版本[1]。在他们的重要论文中,他们令人信服地证明,由大量模型集合获得的知识可以转移到单个小模型。

一个概念块可能阻止了对这种非常有前途的方法的更多研究,那就是我们倾向于用学习到的参数值来识别训练模型中的知识,这使得我们很难看到如何改变模型的形式,但保持相同的知识。知识的一个更抽象的观点是,它是一个从输入向量到输出向量的学习映射,可以使知识从任何特定的实例化中解放出来。对于学习区分大量分类的繁琐模型,通常的训练目标是最大化正确答案的平均对数概率,但这种学习的一个副作用是,经过训练的模型会给所有不正确的答案分配概率,即使这些概率非常小,其中一些也会比另一些大得多。不正确答案的相对概率告诉我们很多关于繁琐模型如何趋向于推广的信息。举个例子,宝马的形象被误认为是垃圾车的几率可能很小,但这个错误的可能性仍然比误认为是胡萝卜的可能性大很多倍。

人们普遍认为,用于训练的目标函数应尽可能地反映用户的真实目标。尽管如此,当实际目标是很好地推广到新的数据时,模型通常被训练来优化训练数据的性能。显然,最好是训练模型很好地进行泛化,但这需要有关正确泛化方式的信息,而这些信息通常是不可用的。然而,当我们从一个大模型中提取知识到一个小模型中时,我们可以像训练大模型一样训练小模型进行泛化。如果笨重的模型能够很好地推广,比如说,它是一个大的不同模型集合的平均值,那么一个小的模型,经过同样的方法的推广训练,在测试数据上通常会比一个小的模型,在训练集合时用同样的训练方法进行常规的训练要好得多。

将笨重模型的泛化能力转化为小模型的一个明显方法是利用笨重模型产生的类概率作为训练小模型的“软目标”。对于这个转移阶段,我们可以使用相同的训练集或单独的“转移”集,当笨重的模型是一个简单模型的大集合时,我们可以使用它们各自预测分布的算术或几何平均值作为软目标。当软目标具有较高的熵时,相对于硬目标,软目标在每个训练案例中提供的信息量更多,训练案例之间的梯度变化也更小,因此小模型往往比原始的繁琐模型能够在更少的数据上进行训练,并且使用更高的学习率。

对于MNIST这样的任务,繁琐的模型几乎总是能以非常高的置信度得出正确的答案,关于所学习函数的大部分信息都存在于软目标中极小概率的比率中。例如,一个版本的2可能有 1 0 − 6 10^{-6} 106的概率是3和 1 0 − 9 10^{-9} 109的概率为7,而对于另一个版本,可能是另一种方式。这些有价值的信息定义了数据上丰富的相似结构(也就是说哪个2看起来像3,哪个看起来像7)但是它对传递阶段的交叉熵代价函数影响很小,因为概率非常接近于零。Caruana和他的合作者通过使用logits(最终softmax的输入)而不是由softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的logits和小模型产生的logits之间的平方差。我们更普遍的解决方案,称为“蒸馏”,是提高最终softmax的温度,直到笨重的模型产生一组合适的软目标。然后我们在训练小模型时使用相同的高温来匹配这些软目标。我们稍后将展示,匹配繁琐模型的logits实际上是蒸馏的一个特例。

用于训练小模型的传输集可以完全由未标记的数据组成[1],或者我们可以使用原始的训练集。我们发现使用原始训练集效果很好,特别是当我们在目标函数中加入一个小项,鼓励小模型预测真实目标,并与笨重模型提供的软目标匹配时,通常小模型不能与软目标精确匹配,并且会朝着正确的答案是有帮助的。

2 Distillation

神经网络通常通过使用“softmax”输出层来生成类概率,该输出层通过将每个类的logit,zi与其他logit进行比较,将其转换为概率qi。

q i = e x p ( z i / T ) Σ j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\Sigma_jexp(z_j/T)} qi=Σjexp(zj/T)exp(zi/T) (1)

其中T是通常设置为1的温度。对T使用更高的值会在类上产生更柔和的概率分布。

在最简单的蒸馏形式中,通过在传递(transfer)集上对知识进行训练,并对传递集中的每个样本使用软目标分布,从而将知识传递到蒸馏模型中,该软目标分布是使用其softmax中具有高温的繁琐模型生成的。训练蒸馏模型时使用相同的高温,但训练后使用的温度为1。

当所有或部分传输集都知道正确的标签时,还可以通过训练蒸馏模型来生成正确的标签来显著改进此方法。一种方法是使用正确的标签来修改软目标,但是我们发现更好的方法是简单地使用两个不同目标函数的加权平均。第一个目标函数是与软目标的交叉熵,该交叉熵是使用蒸馏模型的softmax中的相同高温计算的,该高温与从繁琐模型生成软目标的温度相同。第二个目标函数是带有正确标签的交叉熵。这是在蒸馏模型的softmax中使用完全相同的logits计算的,但温度为1。我们发现,在第二个目标函数上使用较低的权重通常可以获得最佳结果。由于软目标产生的梯度大小为 1 / T 2 1/T^2 1/T2,因此在同时使用硬目标和软目标时,将其乘以 T 2 T^2 T2是很重要的。这确保了在使用元参数进行实验时,如果用于蒸馏的温度发生变化,则硬目标和软目标的相对贡献大致保持不变。

2.1 Matching logits is a special case of distillation

传递集中的每一个样本都有一个交叉熵梯度, d C d z i \frac{dC}{dz_i} dzidC,相对应于蒸馏模型的每一个logit, z i z_i zi。如果笨重的模型具有产生软目标概率 p i p_i pi的logits v i v_i vi,并且转移训练是在T的温度下进行的,则该梯度由下式给出:

在这里插入图片描述

如果温度比对数的量级大小高,我们可以近似:

在这里插入图片描述
如果我们假设logits对于每个transfer case是均值为0 的,那么有: Σ j z j = Σ j v j = 0 \Sigma_jz_j=\Sigma_jv_j=0 Σjzj=Σjvj=0,公式3可简化为:

在这里插入图片描述

所以在高温限制下,蒸馏等于最小化 1 / 2 ( z i − v i ) 2 1/2(z_i-v_i)^2 1/2(zivi)2,*:*假设每个转移案例的logit均值分别为零。在较低的温度下,蒸馏对比平均值负得多的对数匹配的关注要少得多。这可能是有利的,因为这些logtis几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂(very noisy)。另一方面,非常negative的logits可以传达关于由繁琐模型获得的知识的有用信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获笨重模型中的所有知识时,中间温度的效果最好,这强烈表明ignoring the large negative logits是有帮助的。

3 Preliminary experiments on MNIST

为了观察蒸馏的效果,我们在60000个训练案例中训练了一个包含1200个校正线性隐藏单元(修正线性隐藏单位)的两个隐藏层的大型神经网络。如[5]所述,使用dropout和权重约束对网络进行了强正则化。Dropout可以被看作是一种训练一个指数级的共享权重的模型集合的方法。此外,输入图像在任何方向上都会受到最多两个像素的抖动。这个网络实现了67个测试误差,而一个较小的网络,有两个隐层,800个校正线性隐藏单元(rectified linear hidden units),没有正则化,实现了146个误差。但是,如果仅仅通过添加一个额外任务匹配大网络再温度为20时的soft target来修正网络,达到了74的误差。这说明软目标可以将大量的知识转移到提取的模型中,包括如何从translated训练数据中学习到的知识,即使转移集不包含任何translated。

当蒸馏网在其两个隐藏层中的每一个中有 300 个或更多单位时,所有高于 8 的温度都给出了相当相似的结果。但是当这从根本上减少到每层 30 个单位时,2.5 到 4 的温度比更高或更低的温度工作得更好。

然后,我们尝试从传输集中省略所有数字3的例子。因此,从经过提炼的模型的角度来看,3是一个虚构的数字,它从未见过。尽管如此,经过蒸馏的模型只产生了206个测试错误,其中133个是在测试集中1010个的3的分类样本中。大多数的错误是由于3类的学习偏差太低造成的。如果这个偏差增加了3.5(这优化了测试集的整体性能),那么经过蒸馏的模型会产生109个错误,其中14个错误出现在3这个分类中。因此,在正确的偏差下,尽管在训练中从未见过3,但经过提炼的模型得到了98.6%的测试3的正确率。如果转移集只包含训练集的类别7和类别8,则提取的模型产生47.3%的测试误差,但是当7和8的偏差降低7.6以优化测试性能时,测试误差降至13.2%。

5 Training ensembles of specialists on very big datasets

训练一个模型集合是利用并行计算的一种非常简单的方法,通常的反对意见是,一个集合在测试时需要太多的计算,可以用蒸馏法来处理。然而,对集合的另一个重要反对意见是:如果单个模型是大型神经网络,并且数据集非常大,那么在训练时所需的计算量是 过多的,即使很容易并行化。

在这一节中,我们给出了这样一个数据集的示例,并展示了学习专家模型(每个模型关注不同的类的可混淆子集)如何减少学习集成所需的总计算量。专注于精细区分的专家的主要问题是它们很容易过度拟合,我们描述了如何通过使用软目标来防止这种过度拟合。

5.1 The JFT dataset

JFT是一个内部Google数据集,有1亿个标签图像,15000个标签。当我们做这项工作时,Google的JFT基线模型是一个深卷积神经网络[7],它已经使用异步随机梯度下降对大量核心进行了大约6个月的训练。这个训练使用了两种并行性[2]。首先,有许多神经网络的副本运行在不同的核心集上,并处理来自训练集的不同小批量。每个副本计算其当前小批量上的平均渐变,并将此渐变发送到分片参数服务器,该服务器将返回参数的新值。这些新值反映了参数服务器自上次向副本发送参数以来接收到的所有渐变。其次,通过在每个核上放置不同的神经元子集,每个复制品分布在多个核上。集成训练是第三种类型的并行性,它可以封装在其他两种类型上,但前提是有更多的核心可用。等待数年来训练一组模型不是一种选择,因此我们需要一种更快的方法来改进基线模型。

5.2 Specialist Models

当类的数量非常大时,笨重的模型应该是一个集合,包含一个基于所有数据训练的通用模型和许多“专家”模型,这些专家模型基于高相似度的子类进行训练,这些数据在类的一个非常容易混淆的子集(如不同类型的蘑菇)。这类专家的softmax可以通过将它不关心的所有类合并到一个垃圾类中而变得更小。为了减少过度拟合和共享学习低级特征检测器的工作,每个专家模型都用广义模型的权值进行初始化。然后,通过训练专家来稍微修改这些权重,专家的一半示例来自其特殊子集,另一半则从训练集的其余部分随机抽样。训练后,我们可以通过增加垃圾箱类的对数来纠正有偏见的训练集,该对数是专业类过度采样的比例。

5.3 Assigning classes to specialists

为了给专家们导出对象类别的分组,我们决定把重点放在我们整个网络经常混淆的类别上。尽管我们可以计算混淆矩阵并将其用作查找此类集群的方法,但我们选择了一种更简单的方法,不需要真正的标签来构造集群。

特别是,我们将聚类算法应用于通用模型,这样一组经常一起预测的类 S m S^m Sm将被用作目标对于我们的一个专家模型m,我们将K-均值算法的在线版本应用于列的协方差矩阵,并得到合理的聚类(如表2所示)。我们试过了几种聚类算法产生了相似的结果。

在这里插入图片描述

5.4 Performing inference with ensembles of specialists

在调查专家模型被提炼出来后会发生什么之前,我们想看看包含专家的集成表现得如何。除了专家模型之外,我们总是有一个通用模型,这样我们就可以处理没有专家的类,从而可以决定使用哪些专家。给定一个输入图像x,我们分两步进行top-one分类:

步骤1:对于每一个测试用例,我们根据一般模型找到n个最可能类。将这组类称为k。在我们的实验中,我们使用n=1。

步骤2:然后我们取所有专家模型m,其特殊的可混淆类子集 S m S^m Sm与k有一个非空的交集,并将其称为活动专家集 A k A_k Ak(注意,该集可能为空)。然后,我们找到所有类上的全概率分布q,这些类最小化:

在这里插入图片描述

其中KL为KL散度, p m , p g p^m,p^g pm,pg表示专家模型或通才全模型的概率分布。 p m p^m pm是一个分布在所有的专家类的m加上一个单一的垃圾箱类,所以当计算它与全q分布的KL散度时,我们求全q分布分配给m的垃圾箱中所有类的所有概率之和。

等式5没有一般的闭式解,尽管当所有的模型为每一类产生一个单一的概率时,解要么是算术平均值,要么是几何平均值,这取决于我们是使用KL(p,q)还是KL(q,p))。我们参数化q = softmax(z) (T = 1),并使用梯度下降法优化logits z w.r.t eq. 5。请注意,必须对每个图像执行此优化。

results

从训练有素的基线全网络开始,专家们训练得非常快(几天而不是几周)。而且,所有的专家都是完全独立训练的。表3显示了基线系统和基线系统以及专家模型的绝对测试精度。61个专业模型的测试精度总体上提高了4.4%。我们还报告了条件测试的准确性,这是通过只考虑属于专家类的示例,并将我们的预测限制在该类的子集的准确性。

在我们的JFT专家实验中,我们培训了61个专家模型,每个模型有300个类(加上垃圾箱类)。因为专家的类集合不是不相交的,所以我们经常有多个专家覆盖一个特定的图像类。表4显示了测试集示例的数量、使用专家时在位置1处正确的示例数量的变化,以及按涵盖该类的专家数量细分的JFT数据集top1精度的相对提高百分比。当我们有更多的专家覆盖一个特定的类时,精确度的提高会更大,这是一个总的趋势,这让我们感到鼓舞,因为训练独立的专家模型非常容易并行化。

6 Soft Targets as Regularizers

多的专家覆盖一个特定的类时,精确度的提高会更大,这是一个总的趋势,这让我们感到鼓舞,因为训练独立的专家模型非常容易并行化。

6 Soft Targets as Regularizers

我们关于使用软目标而不是硬目标的一个主要主张是,在软目标中可以携带很多有用的信息,而软目标不可能用一个硬目标编码。在本节中,我们通过使用少得多的数据来拟合前面描述的基线语音模型的85M参数来证明这是一个非常大的效果。表5显示,只有3%的数据(大约2000万个例子),用硬目标训练基线模型会导致严重的过度拟合(我们提前停止了,因为在达到44.5%后精度急剧下降),而用软目标训练的同一模型能够恢复整个训练集中几乎所有的信息(约2%shy)。更值得注意的是,我们不必提前停止:软目标系统只是“收敛”到57%。这表明,软目标是一种非常有效的方法,可以将一个基于所有数据训练的模型发现的规律传递给另一个模型。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值