《Distilling the Knowledge in a Neural Network》阅读

https://arxiv.org/abs/1503.02531
Hinton, J.Dean, NIPS 2015 引用量-3144

摘要

一个简单的改善几乎任何机器学习算法的做法是在同样的数据集上训练很多不同的模型,然后去取它们预测结果的均值。但不幸的是,如果要给大量用户使用,那么用一个庞大的完整模型来预测是非常累赘笨拙、计算复杂度很高的,尤其是这个工业模型是一个大的神经网络。引用[1]提到可以把一个集成模型的知识压缩成一个单一的模型,更容易部署,我们发展了这个想法,用了一个不同的技术来压缩。我们在MNIST上取得了令人惊讶的效果。我们展示了我们可以通过把一个集成模型压缩到一个单一的model来很大的提升商业系统中采用的笨重的原始模型。不仅如此,我们还介绍了一种由一个或多个完整模型和多个专家模型组成的新型集成,这些模型学习如何区分完整模型混淆的细粒度类(fine-grained classes)。与直接混合专家模型不同,这些专家模型可以快速并行地进行训练。

1 介绍

在大规模机器学习中,我们总是用相似的模型训练和部署,虽然这两个过程有很大的不同:训练的时候是从一个非常大的、高度冗余的数据及上提取特征,但是这个过程不需要实时操作,可以用上巨大的计算资源。但是要部署给大量用户使用,在时间和资源上就会有更严格的需求。
为了提取数据特征的效果,我们确实需要训练一个笨重的大模型。这个大模型可能是由多个分开的模型集成的也可能是一个带有很强正则化的单一模型。当这个大模型训练完之后,我们用一个特别的训练技术,叫做”蒸馏“,把大模型知识迁移到小模型上,让它更适宜部署。
有一个概念阻止了这个方向的研究:我们总是倾向于用训练好的参数来识别训练模型中的知识,这使得我们不知道如何改变模型结构又保留了相同的知识。

对一个大模型来说,它的任务是学习从很多种类中分辨出数据的种类,使得预测为正确类别的概率最大,但也有一个潜在的影响,就是错误类别的分类概率。虽然对于正确类别来说是远远小于的,但是在错误类别之间比较它们的概率,还是有很大差别。
这些错误类别的预测概率实际上也给我们提供了这个模型生成趋势的一些信息。例如,一张宝马车的图片,会有非常小的概率被预测为一辆垃圾车,但是这个概率也比预测成一个胡萝卜大得多。

众所周知,我们训练的目标是新数据的泛化性能,而如何正确的泛化,这个指导信息无从得知。但对于小模型来说,我们对它进行知识蒸馏操作的时候就是把大模型的泛化功能传授给小模型。如果一个大模型的泛化性能足够好,比如是通过很多不同模型的集成得到,那么蒸馏学习到的小模型肯定会比直接用训练数据训练的小模型做得更好。

将大模型的泛化能力传授给小模型的一个显而易见的方法,就是将大模型产生的类的概率直接作为小模型的”软目标soft targets“。在迁移阶段,我们可以使用相同的训练集,也可以使用单独的迁移集。如果大模型是很多简单模型的集合,我们可以用它们每个各自的预测的算术、几何均值作为soft targets。soft targets都有很高的熵,相比于hard targets(labels)它们提供更多训练中的信息,训练用例的梯度的方差更小,所以小模型经常可以比大模型用更少的数据、更大的学习率。

对于MNIST这样的任务,大模型几乎总是能以非常高的置信度得出正确的答案,关于所学习函数的大部分信息都存在于软目标中极小概率的比率中。
例如,一个版本的2可能是 1 0 − 6 10^{-6} 106的概率是3, 1 0 − 9 10^{-9} 109的概率是7,而另一个版本的2可能是 1 0 − 6 10 ^{- 6} 106的概率是7, 1 0 − 6 10^{ - 6} 106的概率是3。这是有价值的信息,它定义了数据上丰富的相似结构(即,它说哪个2看起来更像3,哪个更像7),但它对传输阶段的交叉熵loss函数影响很小,因为概率非常接近于零。
Caruana和他的合作者绕过这个问题,通过使用logits(softmax层的输入)而不是将softmax产生的概率作为小模型学习的目标。通过最小化大小模型的logits之间的平方差作为目标函数。
我们的更通用的解决方案,称为“蒸馏”,是提高最终的softmax的温度T,直到大模型产生一个适当的软目标集。然后,我们使用相同的T训练小模型来匹配这些软目标。稍后我们将说明,匹配大模型的logits实际上是蒸馏的一个特殊情况。

用于小模型的迁移数据集可以完全由无标注数据组成[1]或我们可以使用原始训练集。我们发现,使用原始训练集效果好,特别是如果我们添加一个目标函数,鼓励小模型来预测真正的目标hard targets,同时匹配大模型提供的soft targets。通常,小模型不能准确地匹配soft targets,会偏离到正确的答案上,这是有帮助的。

2 蒸馏

当有真实标签时,训练蒸馏网络的时候也使用真实标签会很有效。一个做法是使用真实标签去修改soft targets,但我们想到一个更好的方式就是直接加权平均soft targets和hard targets。

我们可以先训练好一个teacher网络,然后将teacher的网络的输出结果 q q q 作为student网络的目标,训练student网络,使得student网络的结果 p p p接近 q q q ,因此,我们可以将损失函数写成
L = C E ( y , p ) + α C E ( q , p ) L=CE(y,p)+\alpha CE(q,p) L=CE(y,p)+αCE(q,p)

这里CE是交叉熵(Cross Entropy), y y y是真实标签的onehot编码, q q q是teacher网络的输出结果, p p p是student网络的输出结果。作者表示第二项的权重取小一些结果更好。

但是 q q q直接使用teacher网络的分类结果,softmax的输出,不太合适。如果直接使用输出作为 q q q,student网络很难学到类别之间的“相似度信息”,因为训练好的teacher网络在标签类上置信度很高,其他都很低,比例上来说和标签比较相似的类也并没有凸显出来。因此提出了softmax-T:
q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)} qi=jexp(zj/T)exp(zi/T)
这里 q i q_i qi即student网络去学习的对象(soft targets), z i z_i zi是teacher网络softmax前的输出logit。
T的取值为1时即普通的softmax;
T的取值为0时最大值接近1,其他值接近0,类似onehot;
T的取值越大,输出结果的分布越均匀,相当于平滑;
T的取值为正无穷时结果是均匀分布。

根据上述 L L L q q q训练student网络。
训练时student网络也用相同的T,但训练完成的前向推理用T=1。

2.1 Distillation在特殊情况下等于直接匹配logits

怎么证明?交叉熵求导,T很大时式子的近似。
q i = 1 T ( e z i T ∑ j e z i T ) q_i=\frac{1}{T}\left(\frac{e^{\frac{z_i}{T}}}{\sum_{j} e^{\frac{z_i}{T}}}\right)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值