因为直接训练一个小的网络,往往结果就是没有从大的网络剪枝好。知识蒸馏的概念是 一样的,因为直接训练一个小的网络,没有小的网络根据大的网络来学习结果要来得 好。
因而,先训练一个 大的网络,这个大的网络在知识蒸馏里面称为教师网络(teachernetwork),其是老师。我们 要训练的是真正想要的小的网络,即学生网络(studentnetwork)。先训练一个大的网络称为 教师网络。再根据这个大的网络来制造学生网络。在网络剪枝里面,直接把那个大的网络做一 些修剪,把大的网络里面一些参数拿掉,就把它变成小的网络。在知识蒸馏里面是不一样的, 这个小的网络(学生网络)是去根据教师网络来学习。假设要做手写数字识别,就把训练数据 都丢到教师里面,教师就产生输出,因为这是一个分类的问题,所以教师的输出其实是一个分 布。 比如教师的输出可能是看到这张图片1的分数是0.7,7的分数是0.2,9 这个数字的分 数是0.1 等等。接下来给学生一模一样的图片,但是学生不是去看这个图片的正确答案来学 习,它把老师的输出就当做正确答案,也就是老师输出1要0.7,7要0.2,9要0.1。学生的 输出也就要尽量去逼近老师的输出,尽量去逼近1是0.7、7是0.2、9是0.1这样的答案。学 生就是根据老师的答案学,就算老师的答案是错的,学生就去学一个错的东西。
其实知识蒸馏也不是新的技术,知识蒸馏最知名的一篇文章Hinton在15年的时候已经 发表论文了。很多人会觉得知识蒸馏是Hinton提出来的,因为Hinton有一篇论文“Distilling the Knowledge in a Neural Network”。但其实在 Hinton 提出知识蒸馏这个概念之前,其实就 有看过其他文章使用了一模一样的概念。举例来说,论文“Do Deep Nets Really Need to be Deep”是一篇 13 年的文章里面,也提出了网络蒸馏的想法。
为什么知识蒸馏会有帮助呢?一个比较直觉的解释是教师网络会提供学生网络额外的信 息,如图1 所示,如果直接跟学生网络这是1,可能太难了。因为1可能跟其他的数字有 点像,比如1跟7也有点像,1跟9也长得有点像,所以对学生网络,我们告诉它:看到这张 图片我们要输出1。7、9的分数都要是0,可能很难,它可能学不起来,所以让它直接去跟 老师学,老师会告诉它这是1。我们没有办法让它是1分,也没有关系。其实1跟7是有点 像的,老师都分不出1跟7的差别。老师说1是0.7,7是0.2,学生只要学到1是0.7,7是 0.2 就够了。这样反而可以让小的网络,学得比直接从头开始训练,直接根据正确的答案要学 来得要好。
图1 知识蒸馏
Hinton 论文里面甚至可以做到教师告诉学生哪些数字之间有什么样的关系这件事情,就 可以让学生在完全没有看到某些数字的训练数据下,就可以把那一个数字学会。假设训练数 据里面完全没有数字7,但是教师在学的时候有看过数字7,但是学生从来没有看过数字7 。但光是凭着教师告诉学生说1跟7有点像,7跟9有点像这样子的信息,都有机会让学生 可以学到7长什么样子。就算它在训练的时候,从来没有看过7的训练数据。这是知识蒸馏 的基本概念。
教师网络不一定要是单一的巨大网络,它甚至可以是多个网络的集成,训练多个模型,输 出的结果就是多个模型,投票的结果就结束了。或者是把多个模型的输出平均起来的结果当 做是最终的答案。虽然在比赛里面,常常会使用到集成的方法。如果在一个机器学习的比赛排 行榜里面要名列前茅,往往凭借的就是集成技术,就是训练多个模型,把那么多的模型的结果 通通平均起来。但是在实用上,集成会遇到的问题就训练了1000个模型,进来一笔数据,我 们要1000 个模型都跑过,再取它的平均,计算量也未免太大了。打比赛还勉强可以。要用在 实际的系统上显然是不行的,可以把多个集成起来的网络综合起来变成一个,如图2所示。
图2 使用网络集成作为教师网络的输出
这个就要用知识蒸馏的做法,就把多个网络集成起来的结果当做是教师网络的输出。让学生 网络去学集成的结果,让学生网络去学集成的输出,让学生网络去逼进一堆网络集成起来的 正确率。
在使用知识蒸馏的时候有一个小技巧。这个小技巧是稍微改一下 Softmax 函数,会在 Softmax 函数上面加一个温度(temperature)。Softmax 要做的事情就是把每一个神经元的输 出都取指数,再做归一化,得到最终网络的输出,如下式 所示。网络的输出变成一个概 率的分布,网络最终的输出都是介于0到1之间的。所谓温度,就是在做取指数之前,把每 一个数值都除上T,如下式 所示。
其中T 是一个需要调整的超参数。假设T >1,温度T 的作用就是把本来比较集中的分 布变得比较平滑一点。举个例子,如下式所示,假设y1,y2,y3 是原始的值,y′ 1,y′ 2,y′ 3 是 Softmax 后的值,softmax 后的值都趋近于 0。
假设教师网络的输出如上式所示,让学生要叫教师网络去跟这个结果学,跟直接和 正确的答案学完全没有不同。跟教师学的一个好处就是,老师会告诉我们说哪些类别其实是 比较像的,让学生网络在学的时候不会那么辛苦。但是假设老师的输出非常地集中,其中某 一个类别是1,其他都是0。这样子跟正确答案学没有不同,所以要取一个温度。假设温度T 设为100,如下式所示,对于教师,加上温度,分类的结果是不会变的。做完Softmax 以后, 最高分的还是最高分,最低分的还是最低分。所有类别的排序是不会变的,分类的结果是完全 不会变的。但好处是每一个类别得到的分数会比较平滑、比较平均,我们拿这一个结果去给学 生学才有意义,才能够把学生学得好。这是知识蒸馏的一个小技巧。
温度太大,模型会会改变很多。假设温度接近无穷大,这样所有的类别的分数就变得差不 多,学生网络也学不到东西了,因此T 又是另外一个超参数,它就跟学习率一样,这个是我 们在做知识蒸馏的时候要调的参数。