为了提高模型准确率,我们习惯用复杂的模型(网络层次深、参数量大),甚至会选用多个模型集成的模型,这就导致我们需要大量的计算资源以及庞大的数据集去支撑这个“大”模型。但是,在部署服务时,就会发现这种“大”模型推理速度慢,耗费内存/显存高,这时候我们又会想念“小”模型的好。那么,有没有一种方法能够尽可能继承大模型的泛化能力,又像小模型一样轻量级呢?今天来介绍一种模型压缩的方法——蒸馏(Distillation)。
传统的蒸馏
首次提出知识蒸馏压缩模型思想的是2006年Bucilua,但是论文里没有实际工作阐述:https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf。
所以,一般认为最早是Hinton在2015年提出并应用在了分类任务上:Distilling the Knowledge in a Neural Network。我们来阐述一下传统的知识蒸馏过程:简单地说,就是先用数据集训练一个效果非常好的Teacher模型,然后选择一个较为轻量级的Student模型,同时接受数据集和来自Teacher模型给予的Knowledge Transfer的“知识”来训练这个轻量级Student模型。那么整个蒸馏的过程中,我们主要关心的就是Teacher模型的选择、Student模型的选择、以及Student模型的训练过程(或者说是Knowledge Transfer过程)。
Teacher模型:首先,我们需要一个原始的“大”模型——Teacher模型,这个模型可以不限制其结构、参数量、是否集成,要求这个模型尽可能精度高,并且对于给定的输入X可以给出输出的监督信息Y,这个Y在分类任务中就是softmax的结果,也就是输出对应类别的概率值。这里我们称Y为soft targets,而训练数据的标注好的标签,我们称为hard targets。
Student模型:这个部分的模型选择会有很多限制,要求其参数量小,结构相对简单,当然最好是单模型。并且需要注意的是,训练过程中student模型学习的不再是单纯的hard targets(标注好的真实标签),而是融入teacher模型输出的soft targets(监督信息Y),这里也被称为knowledge transfer。蒸馏的损失函数distillation loss分为两部分:一部分计算teacher和student之间输出预测值的差别(student预测的y 和 soft targets),另一部分计算student原本的loss(student预测的y 和 hard targets),这两部分做凸组合作为整个模型训练的损失函数来进行梯度更新,最终获得一个同时兼顾精度和性能的student模型。
这里单独说一下teacher和student之间输出预测值的loss,这个部分被做的文章也是比较多,这实际上是两个分布的距离问题,可以选择传统的Cross,也可以选择MSE、KL散度等,在博主的实验里发现对不同的student模型,适合不同的loss函数,这里只能自己多做尝试。
为什么蒸馏会有效?
那么,肯定有人想问,为什么蒸馏会有效?直接从数据集学习不是更为直观没有中间商赚差价吗?本质上,蒸馏的训练方式主要是改变了模型只能单一地学习label的这个缺陷。原