Distilling the Knowledge in a Neural Network
论文地址:https://arxiv.org/pdf/1503.02531v1
动机
这篇论文的研究动机是探索一种有效的方法,将多个模型的知识压缩到一个单一的模型中,以便于更容易地部署和使用。同时,研究团队想探索一种新的模型组合方式,并在语音识别和MNIST等任务中实现更好的性能。
贡献
这篇论文提出了一种称为"知识蒸馏"的方法,可以将一个由多个神经网络组成的集合压缩成一个单一的模型,以便更轻松地部署。作者们通过实验表明,这种方法可以显著提高MNIST数据集和自动语音识别系统的性能。此外,他们还提出了一种新型的集合方法,包括一个或多个完整模型和许多专家模型,这些模型学习区分全模型混淆的细粒度类别。作者们的方法可以在大型神经网络上获得更好的性能,同时还可以并行地训练多个专家模型。因此,这篇论文的贡献在于提出了一种有效的方法来压缩神经网络集合,并提高了神经网络的性能。
摘要
提高几乎任何机器学习算法性能的一种非常简单的方法是在同一数据上训练许多不同的模型,然后对它们的预测进行平均。不幸的是,使用整个模型集合进行预测很麻烦,并且可能计算成本太高,无法部署到大量用户,尤其是当单个模型是大型神经网络时。我们在 MNIST 上取得了一些令人惊讶的结果,我们表明我们可以通过将模型集合中的知识提炼为单个模型来显着改进大量使用的商业系统的声学模型。还引入了一种新类型的集成,由一个或多个完整模型和许多专家模型组成,这些模型学习区分完整模型混淆的细粒度类。与专家的混合不同,这些专业模型可以快速并行训练。
简介
许多昆虫都有一种幼虫形式,它能从环境中提取能量和营养物质,而一种完全不同的成虫形式,能满足非常不同的旅行和繁殖需求。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大的、高度冗余的数据集中提取结构,但它不需要实时操作,并且可以使用大量的计算。然而,部署到大量用户对延迟和计算资源的要求更加严格。与昆虫的类比表明,如果能更容易地从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可以是单独训练的模型的集合,也可以是使用dropout等非常强的正则化器训练的单个非常大的模型。一旦繁琐的模型得到训练,我们就可以使用另一种训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型中。
一个概念上的障碍可能阻碍了对这种非常有前途的方法进行更多的研究,那就是我们倾向于用学到的参数值来识别训练过的模型中的知识,这使得我们很难看到如何能够改变模型的形式而保持相同的知识。对知识的一种更抽象的看法是,它是一种从输入向量到输出向量的学习映射,这使它摆脱了任何特定的实例化。对于学习区分大量类的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是,训练过的模型为所有错误答案分配概率,即使这些概率非常小,其中一些也比其他的大得多。例如,一辆宝马的图像可能只有很小的几率被误认为是一辆垃圾车,但这种错误的可能性仍然比误以为是一根胡萝卜的可能性高很多倍。
人们普遍认为,用于训练的目标函数应该尽可能地反映用户的真实目标。尽管如此,模型的训练通常是为了优化训练数据的性能,而真正的目标是对新数据进行良好的泛化。训练模型能够很好地泛化显然会更好,但这需要有关泛化的正确方法的信息,并且这些信息通常不可用。然而,当我们将知识从一个大模型提炼成一个小模型时,我们可以训练小模型以与大模型相同的方式进行泛化。如果繁琐的模型泛化得很好,例如,它是不同模型的大型集合的平均值,那么以相同方式训练泛化的小模型在测试数据上的表现通常会比在用于训练集合的相同训练集上以正常方式训练的小模型好得多。
将繁琐模型的泛化能力转移到小模型的一个明显方法是将繁琐模型产生的类概率作为训练小模型的 “软目标”。对于这个转移阶段,我们可以使用相同的训练集或一个单独的 "转移 "集。当繁琐的模型是由较简单的模型组成的大集合时,我们可以使用它们各自预测分布的算术或几何平均值作为软目标。当软目标具有高熵时,它们在每个训练案例中提供的信息要比硬目标多得多,并且在训练案例之间的梯度中方差要小得多,因此小模型通常可以在比原始繁琐模型少得多的数据上进行训练,并且使用更高的学习率。
对于像MNIST这样的任务,繁琐的模型几乎总是以非常高的置信度产生正确的答案,关于所学函数的大部分信息都存在于软目标中非常小的概率的比率。例如,一个版本的2可能有10-6的概率是3,10-9的概率是7,而另一个版本可能是相反的。这是很有价值的信息,它在数据上定义了丰富的相似性结构(即它说哪些2看起来像3,哪些看起来像7),但它对转移阶段的交叉熵成本函数影响很小,因为概率非常接近于零。Caruana 和他的合作者通过使用 logits(最终 softmax 的输入)而不是 softmax 产生的概率作为学习小模型的目标来规避这个问题,并且它们最小化了繁琐模型产生的 logits 与小模型产生的 logits 之间的平方差异。我们更普遍的解决方案,称为 “蒸馏”,就是提高最终softmax的温度,直到繁琐的模型产生一套合适的软目标。然后我们在训练小模型时使用同样的高温来匹配这些软目标。我们在后面说明,匹配繁琐模型的对数实际上是蒸馏法的一个特例。
用于训练小模型的传输集可以完全由未标记的数据组成,或者我们可以使用原始训练集。我们发现使用原始训练集效果很好,尤其是当我们在目标函数中添加一个小项来鼓励小模型预测真实目标并匹配繁琐模型提供的软目标时。通常,小模型不能完全匹配软目标并在正确答案的方向出错是有帮助的。
Distillation
神经网络通常通过使用 "softmax "输出层来产生类别概率,该输出层通过比较
z
i
z_i
zi和其他对数,将为每个类别计算的对数
z
i
z_i
zi转换成概率
q
i
q_i
qi。
q
i
=
e
x
p
(
z
i
/
T
)
∑
j
e
x
p
(
z
j
/
T
)
q_i=\frac{exp(z_i/T)}{\sum_j exp(z_j/T)}
qi=∑jexp(zj/T)exp(zi/T)
其中 T 是通常设置为 1 的温度。对于 T 使用更高的值会在类上产生更柔和的概率分布。
在最简单的蒸馏形式中,知识被转移到蒸馏模型中,方法是在转移集上训练它,并对转移集中的每个案例使用软目标分布,这个软目标分布是通过使用繁琐的模型,在其softmax中使用高温产生的。在训练蒸馏模型时,也使用同样的高温,但在它被训练后,它使用的温度为1。
当所有或部分转移集的正确标签都是已知的,这种方法可以通过同时训练蒸馏模型来产生正确的标签而得到显著的改善。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均。第一个目标函数是与软目标的交叉熵,这个交叉熵的计算是使用蒸馏模型的softmax中相同的高温,就像从繁琐的模型中生成软目标那样。第二个目标函数是带有正确标签的交叉熵。这是在蒸馏模型的softmax中使用完全相同的logits计算的,但温度为1。我们发现,在第二个目标函数上使用适当低的权重通常可以获得最佳结果。由于软目标产生的梯度大小为 1 / T 2 1/ T^2 1/T2,因此在使用硬目标和软目标时,将其乘以 T 2 T^2 T2是很重要的。这确保了在进行元参数实验时,如果改变用于蒸馏的温度,硬目标和软目标的相对贡献大体上保持不变。
1、Matching logits is a special case of distillation
转移集中的每个案例都贡献了一个交叉熵梯度,
d
C
/
d
z
i
dC/dz_i
dC/dzi,对于每个logit,蒸馏模型的
z
i
z_i
zi。如果繁琐模型的logits
v
i
v_i
vi产生软目标概率
p
i
p_i
pi,并且迁移训练在温度T下进行,则该梯度为:
∂
C
∂
z
i
=
1
T
(
q
i
−
p
i
)
=
1
T
(
e
z
i
/
T
∑
j
e
z
j
/
T
−
e
v
i
/
T
∑
j
e
v
j
/
T
)
\frac{\partial C}{\partial z_i}=\frac{1}{T}\left(q_i-p_i\right)=\frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}}-\frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)
∂zi∂C=T1(qi−pi)=T1(∑jezj/Tezi/T−∑jevj/Tevi/T)
如果温度比对数的大小高,我们可以近似:
∂
C
∂
z
i
≈
1
T
(
1
+
z
i
/
T
N
+
∑
j
z
j
/
T
−
1
+
v
i
/
T
N
+
∑
j
v
j
/
T
)
\frac{\partial C}{\partial z_i}\approx\frac{1}{T}\left(\frac{1+z_i/T}{N+\sum_j z_j/T}-\frac{1+v_i/T}{N+\sum_j v_j/T}\right)
∂zi∂C≈T1(N+∑jzj/T1+zi/T−N+∑jvj/T1+vi/T)
如果我们现在假设对数已经为每个转移情况分别进行了零均值化,使得
∑
j
z
j
=
∑
j
v
j
=
0
\sum_j z_j = \sum_j v_j = 0
∑jzj=∑jvj=0,公式3简化为:
∂
C
∂
z
i
≈
1
N
T
2
(
z
i
−
v
i
)
\frac{\partial C}{\partial z_i}\approx\frac{1}{NT^2}\left(z_i-v_i\right)
∂zi∂C≈NT21(zi−vi)
因此,在高温极限下,蒸馏等价于最小化
1
/
2
(
z
i
−
v
i
)
2
1/2(z_i−v_i)^2
1/2(zi−vi)2,前提是每个传输情况的对数分别被零均值。在较低的温度下,蒸馏对那些比平均水平负得多的对数值的关注要少得多。这有潜在的优势,因为这些对数几乎完全不受用于训练繁琐模型的成本函数的约束,所以它们可能是非常嘈杂的。另一方面,非常负的 logits 可能会传达有关繁琐模型获得的知识的有用信息。这些影响中的哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法在繁琐的模型中捕获所有知识时,中间温度效果最好,这表明忽略大的负 logits 可能会有所帮助。