文章目录
Logits Distillation
参考论文:https://arxiv.org/abs/1503.02531
损失计算:
以分类问题举例
Soft-Target:关于分类结果的概率输出
Hard-Target:GT的类别,是一个独热编码
Loss:
- Hard Loss,即在Student Model的输出和Hard Target之间计算交叉熵损失CE
- Soft Loss,在Teacher Model和Student Model进行SoftMax操作之前先进行升温操作(除以一个不为1的常数T),将SoftMax操作之后的输出进行均方误差损失MSE(也可以是交叉熵损失CE)
T
o
t
a
l
L
o
s
s
=
λ
×
S
o
f
t
L
o
s
s
+
(
1
−
λ
)
×
H
a
r
d
L
o
s
s
Total Loss = \lambda \times Soft Loss + (1 - \lambda ) \times Hard Loss
TotalLoss=λ×SoftLoss+(1−λ)×HardLoss
升温操作:
- T = 1 时,就是和原始模型的概率分布输出一致
- 0 < T < 1时,输出的概率分布和原始模型的输出相似
- T > 1时,输出的概率分布比原始模型的输出更平缓
- 随着T的增加,Softmax 的输出分布越来越平缓,信息熵会越来越大。温度越高,softmax上各个值的分布就越平均,思考极端情况,当 ,此时softmax的值是平均分布的。
Feature Distillation
让Student Model学习Teacher Model的中间层输出
目标:把“宽”且“深”的网络蒸馏成“瘦”且“更深”的网络。
主要分成两个阶段:
- Hints Training:Teacher的1-N层为
W
h
i
n
t
W_{hint}
Whint ,第N层输出为Hint;Student的1-M层为
W
g
u
i
d
e
d
W_{guided }
Wguided,第M层输出为Guided。Hint和Guided的维度可能不匹配,使用一个卷积适配器r将Guided映射到和Hint的维度匹配,最小化下面这个Loss:
算法流程伪代码:
Distillation Scheme
Offline Distillation
离线蒸馏是最传统的知识蒸馏方法。在这个方案中,首先独立地训练一个大的、复杂的模型(教师模型)。教师模型通常会在数据集上进行充分的训练,直到达到很高的精度。一旦教师模型被训练好,它的输出(通常是分类任务中的软标签)就被用来指导小的模型(学生模型)的训练。学生模型的训练是在教师模型固定之后进行的,因此称之为“离线”蒸馏。学生模型试图模仿教师模型的输出,以此来达到比自己直接在数据集上训练更好的性能。
Online Distillation
在线蒸馏指的是教师模型和学生模型同步训练的情况。在这个方案中,不需要预先训练一个固定的教师模型,而是让教师和学生在同一时间内相互学习。有时候,教师模型会在训练过程中动态更新,即学生模型的训练过程中教师模型也在不断地学习和改进。这种方法允许学生模型从教师模型的即时反馈中学习,而教师模型也可以根据学生模型的进展进行调整。在线蒸馏使得模型训练更加灵活,因为它不需要一个训练完成的教师模型作为起点。
Self Distillation
自蒸馏是一个相对较新的概念,它不涉及两个不同的模型,而是在同一个模型上进行迭代训练。一个模型首先被训练,然后它自己的输出被用作下一轮训练的目标。简单来说,模型首先以一定方式训练(例如使用硬标签),一旦完成,模型的预测(软标签)被用来再次训练模型。这个过程可以重复多次,每次模型都尝试模仿自己先前版本的输出。自蒸馏允许模型在没有外部教师模型的情况下提高其性能。这种方案的优势是简单和成本低。