文章目录
1、hard targets 和 soft targets
- 比如在识别马,驴,车的分类任务中,需要对这三个类别打标签,比如对一头马打标签,对于hard targets 就是马:1,驴:0,车:0,而对于soft targets就是马:0.7,驴:0,25,车:0.05
- 显然soft targets的标签更具有科学性,说明了该对象有多像马,有多不像马,所以soft targets可以传递更多的信息,
- 在知识蒸馏中,一般用hard targets作为ground truth训练教师网络,教师网络预测的结果作为soft targets来训练学生网络,
- soft targets包含了更多“知识”和‘“信息’,像谁,不像谁,有多像,有多不像,特别是非正确类别概率的相对大小(驴和车)|
2、蒸馏温度 T
蒸馏温度 T是调节正确类别和非正确类别的概率的相对大小
公式如下:
q i = e x p ( Z i / T ) ∑ i e x p ( Z j / T ) q_{i}=\frac{exp(Z_{i}/ T)}{\sum_{i}exp(Z_{j}/ T)} qi=∑iexp(Zj/T)exp(Zi/T)
T越大,则正确类别的概率和非正确类别的概率的相对差值就越小,如下图,
下图是bili up 同济子豪兄的图示:
3、知识蒸馏的过程
- 首先需要一个提前训练好的教师网络,数据喂给教师网络,蒸馏温度为T,一般参数T<20,输出soft labels,如下图第一行,学生网络可以是训练一定epoch的,也可以是没有训练的,然后同样的数据喂给学生网络,蒸馏温度同样为T,输出为soft predictions,然后教师网络输出的soft labels和学生网络输出的soft predictions做损失函数,目的是蒸馏温度同样为T,学生网络要接近教师网络,
- 接着,同样的数据喂给学生网络,蒸馏温度为1,输出的是hard predictions,hard predictions和hard labels做损失函数,
- 最后给2个损失函数调整合适的权重作为网络的总的损失函数,由于 L d i s L_{dis} Ldis的梯度大约是 L s t u L_{stu} Lstu的 1 T 2 \frac{1}{T^2} T21,因此可以在 L d i s L_{dis} Ldis前乘上 T 2 T^2 T2,可以保证两个损失部分的梯度量贡献基本一致,
也可以参考同济子豪兄绘制的图,
这是另一个比较好的图示,
3.1 代码片段
4、知识蒸馏的对比实验
其中Baseline是学生网络,10xEnsemble是教师网络,Distilled Single model是教师网络蒸馏学生网络,
可以看出使用Soft Targets可以明显防止过拟合,
4.1 另一个对比试验
5、知识蒸馏的应用场景
- 模型压缩
- 优化训练,防止过拟合(潜在的正则化)
- 无限大、无监督数据集的数据挖掘
- 少样本、零样本学习
- 迁移学习和知识蒸馏,迁移学习和知识蒸馏的概念是正交的,迁移学习指的是领域的迁移,知识蒸馏是模型之间的蒸馏,
6、soft targets VS labels smoothing
soft targets是教师模型预测输出的,所以每一个类别的概率是不一样的,而labels smoothing是除了正确类别的概率最高外,其他类别的概率都是一样的(非零常数),这是为了防止正确类别的概率过高,但是不能看出除了最高概率外的类别的差异性
如下图,
参考资料:
1、知识蒸馏综述:CSDN链接
2、yolov5的蒸馏、剪枝:github链接
3、CSDN链接,【目标检测】YOLOv5遇上知识蒸馏:CSDN链接
4、另一个yolov5的蒸馏:github链接
5、同济子豪兄哔站视频讲解:哔站链接
6、哔站up Enzo_Mi,https://www.bilibili.com/video/BV1yN411T7jY/?spm_id_from=333.337.search-card.all.click&vd_source=5435cc6fd84f94da0fa6d26fc8a94994