发布平台:2022 CVPR
论文链接:https://arxiv.org/abs/2203.08679
代码链接:https://github.com/megviiresearch/mdistiller
创新点
揭示了经典KD是高度耦合的公式,为此局限了KD的潜力。将经典KD为两个层次:(1)是对目标类和所有非目标类(TCKD)的二值预测,(2)是对每个非目标类(NCKD)的多类别预测。
TCKD
通过二元对数蒸馏法传递知识。这意味着只提供目标类的预测,而每个非目标类的具体预测是未知的。一个合理的假设是,TCKD传递了关于训练样本 "难度 "的知识。即,该知识描述了识别每个训练样本的难度。
NCKD
非目标类间关系的知识。即,传统KD引导学生网络学习的类间关系。
对KD进行解耦分析
如图b所示。NCKD损失项由一个与教师对目标类的预测置信度呈负相关的系数进行加权。因此,教师网络对正确目标类的预测置信度越大,非目标类的权重就会越小。
备注
通过逻辑蒸馏,我们希望学生网络学习教师网络的预测输出,即学习教师网络提供的不同类别之间的类间关系。
理想的逻辑蒸馏方式:当教师网络预测正确时,引导学生网络在学习目标类知识的同时,充分学习教师网络提供的不同类别之间的类间关系。
由于传统KD被证实是一个耦合的公式,即教师网络预测正确时,大部分数值会集中于正确的目标类,而少部分数值聚集于非目标类,即非目标类的权值与目标类的预测置信度呈负相关。
为此,DKD的做法是,对目标类的知识、非目标类之间的类间关系知识,分开引导学生网络学习。
损失函数
传统KD
表示教师网络经过softmax软化后的预测输出,
表示学生网络经过sofemax软化后的预测输出。
表示目标类之间的相似度。
表示非目标类之间的相似度。
表示非目标类损失项的系数(与教师对目标类的预测置信度呈负相关的系数)。
DKD
通过引入定值
和
分别作为TCKD和NCKD的权重,以此完成KD的解耦过程。
与
表示教师网络与学生网络对目标类预测的二进制概率,
与
表示教师网络与学生网络对所有非目标类预测的二进制概率。
计算公式如下所示:
定义非目标类之间的独立概率分布(不考虑第t类,即目标类)。每个元素的计算公式如下所示: