论文题目:Distilling the Knowledge in a Neural Network
思想总结:
深度神经网络对信息的提取有着很强的能力,可以从大量的数据中学习到有用的知识,比如学习如何将手写数字图片进行0~9的分类。
层数越多(越深),神经单元个数越多的网络,可以在大量的数据中获取的知识越丰富,能力越强。
然而当我们使用一个十分复杂的网络对一个较大的训练集进行训练时,参数众多,网络模型复杂,计算成本太高而无法部署到大量用户。
那么我们是否可以使用一种方法,将复杂的网络获取的知识,转移(提取)到一个相对较小,较简单的网络中呢?使得他们对相同的问题有着相同甚至更优的泛化能力呢?
从而,大大降低计算成本,并使其可以大量部署。
如何将一个已经训练好的、较大、较复杂的网络模型所学习到的知识,提取到另外一个较小、较简单的网络模型中,并使得它们对测试数据集有着相同的泛化能力。
如上图所示,对于一个数字0~4的5个类别的分类问题,我们如何将上面复杂的网络经过训练后学习到的知识,转移到下面较为简单的网络中。
我们使用同样的数据集训练下面较为简单的网络。
(1)要想两个网络对同一问题有相同的能力,首先我们要确保的是对于同一个输入x而言,它们输出的类别结果是一样的。
对于上面的两个网络,由于他们都使用softmax函数,所以要确保他们输出的最大概率的类别号相同。
然而,保证(1)就可以了吗?我们来看看这样一些现象。
1:对于像MNIST这样的任务,复杂的模型几乎总是以非常高的置信度产生正确的答案,模型学习到的很多信息存在于概率非常小的比率中。例如,一种类型的2,被分类为3的概率为10^-6 ,被分类为7的概率为10^−9,而另一个类型的2可能相反。这些不同点是有价值的信息,它定义了数据上丰富的相似结构(说明了哪些2像3而哪些2更像7),但它对传递阶段的交叉熵代价函数影响很小,因为概率非常接近于零。
2:在图片分类问题中,将一辆宝马错误的分类为拖拉机的概率远大于将其分类为胡萝卜的概率。
这些现象说明,当我们进行分类问题时,正确的类别所得到的概率是有用的,并且其他错误的类别所得到的概率同样蕴含这有用的信息,也是网络通过训练而学习到的信息。
所以,在将复杂网络学习到的知识转移到简单网络时,不仅仅要学习正确的类别上的概率,还要学习错误类别上的概率。因为错误类别上的概率同样体现了模型的泛化方式,和所学习到的知识。虽然在错误的类别上的概率较小,但是在有一些错误类别上的概率要比另外一些错误类别的概率大很多,所以这种大小关系任然体现了复杂模型所学习到的一些知识。
(2)因此,在简单网络中,我们不仅要使得正确类别的概率最大,并且要保证其他类别的概率与复杂网络中得到的概率相同。即P1~P5都要相同。
我们将此作为简单模型的目标进行训练,就可以使得知识较好的转移。
同理,我们也可以将一个复杂的网络所学习到的知识,转移到多个简单的网络中,使得每个简单的网络获得复杂网络的一个子功能。
比如:
复杂网络学习到了分类数字0~9的知识。我们可以将其转移到3个简单网络中,使得:
网络1:可以对1~4进行分类
网络2:可以对4~7进行分类
网络3:可以对7~9进行分类
再结合网络1,2,3即可完成复杂网络的所有功能,并且大大降低计算量。