本文介绍知识蒸馏的经典论文(Distilling the knowledge in a neural network)。核心思想是通过迁移知识,从而通过训练好的大模型得到更加适合推理的小模型。
1. 核心思想
文章的核心思想就是提出用soft target
来辅助hard target
一起训练,而soft target
来自于大模型的预测输出:
1、训练大模型:先用hard target
,也就是正常的标签训练大模型。
2、计算soft target
:利用训练好的大模型来计算soft target
。也就是大模型“软化后”再经过softmax
的输出。
3、训练小模型,在小模型的基础上再加一个额外的soft target
的损失函数,通过alpha来调节两个损失函数的比重。
4、预测时,将训练好的小模型按常规方式(右图)使用。
2. 损失函数
class DistillKL(nn.Module):
"""Distilling the Knowledge in a Neural Network"""
def __init__(self, T):
super(DistillKL, self).__init__()
self.T = T
def forward(self, y_s, y_t):
p_s = F.log_softmax(y_s/self.T, dim=1)
p_t = F.softmax(y_t/self.T, dim=1)
loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
return loss
知识蒸馏的关键是损失函数的设计,它包括普通的交叉熵损失和建立在soft target
基础上的损失。
hard target
包含的信息量(信息熵)很低,soft target
包含的信息量大,拥有不同类之间关系的信息。
比如,同时分类驴和马的时候,尽管某张图片是马,但是soft target
就不会像hard target
那样只有马的index
处的值为1,其余为0,而是在驴的部分也会有概率。
这样的好处是,这个图像可能更像驴,而不会去像汽车或者狗之类的,而这样的soft
信息存在于概率中,以及标签之间的高低相似性都存在于soft target
中。
但是如果soft target
是像这样的信息[0.98 0.01 0.01]
,就意义不大了,所以需要在softmax
中增加温度参数T
(这个设置在最终训练完之后的推理中是不需要的)。增加softmax
后的蒸馏损失函数:
综合损失函数:
蒸馏损失的代码实现:
# ==============================蒸馏损失===============================
class DistillKL(nn.Module):
"""Distilling the Knowledge in a Neural Network"""
def __init__(self, T):
super(DistillKL, self).__init__()
self.T = T
def forward(self, y_s, y_t):
# student网络输出软化后结果
# log_softmax与softmax没有本质的区别,只不过log_softmax会得到一个正值的loss结果。
p_s = F.log_softmax(y_s/self.T, dim=1)
# # teacher网络输出软化后结果
p_t = F.softmax(y_t/self.T, dim=1)
# 蒸馏损失采用的是KL散度损失函数
loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
return loss
参考文献
深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network)