知识蒸馏
动机
为了提高网络的性能,采用多个模型训练之后再加权求平均得出输出值。但是这种方法去部署的时候却不容易。针对这个问题,采用的方法有两种:
- 模型压缩
- 训练轻量化模型
知识蒸馏就是采用的模型压缩的方法
思想
训练一个训练好的小网络去模仿一个预先训练好的大型网络或者集成网络
其中:知识的含义是模型的参数信息保留了模型学到的知识,学习如何从输入向量映射到输出向量
例如:教师网络经过softmax
层输出的结果,通常是正确的分类概率比较大;而其他的类别的概率值几乎接近0
。这种结果会忽略掉其它类别的概率中包含的有用信息,没有充分利用到教师网络强大的泛化能力。
例如:真实标签:3
,最后模型最后预测的概率发现:4
的概率小于8
的概率。那么其实模型也可以从这里学习到,更接近8
的形状比更接近4
的形状是真实标签的概率要大。
即在原始的softmax
的基础上添加一个参数T
(温度)使得模型能够更加关注到细节信息
这个表可以看出,增加蒸馏温度,能够很好的捕捉到不同类别之间的有用信息
方法
神经网络预测的过程
-
输入的图片送给卷积神经网络,提取特征
-
拉伸卷积层,送入全连接层
-
多层全连接层得到
logits Zi
-
logits Zi
经过softmax
得到预测概率
蒸馏的过程:
- 教师网络训练
首先利用数据训练一个层数更深,提取能力更强的教师网络,得到logits
后,利用升温T
的softmax
得到预测类别的概率分布soft targets
- 蒸馏
蒸馏教师网络知识到学生网络,构造distillation loss
和student loss
,加权相加作为最后的损失函数
L = a Lsoft + b Lhard
注:soft target
产生梯度的大小按1/T^2
缩放,因此再同时使用soft targets
和hard targets
时,蒸馏损失乘以T^2
特殊蒸馏(直接利用logits)
直接利用softmax
层的输入logits
(而不是输出)作为soft targets
。需要最小化的目标函数时教师网络和学生网络的logits
之间的平方差
- 交叉熵求导
- 当
T
足够大时
此处使用了等价无穷小 - 假设所有的logits对每个样本都是零均值