目录
代码:https://github.com/GzyAftermath/CAT-KD
1.论文核心
代码:https://github.com/GzyAftermath/CAT-KD
1.1研究背景
- 现有的知识蒸馏方法虽性能优异,但很难解释它们传递的知识如何有助于提高学生网络的性能。
- 尽管以前的工作AT[1]已经验证了转移注意力的有效性,但它并没有呈现注意力在分类过程中扮演的角色。且相较于logits和特征转移的知识蒸馏方法[2,3]并不具有竞争力。
1.2创新点
- 我们提出了类注意力转移,并用它来证明识别输入的类判别区域的能力,这是CNN进行分类的关键,可以通过转移(Class Attention Map)CAM来获得和增强。
- 我们提出了转移CAM的几个有趣的特性,这有助于更好地理解CNN。
- 我们将CAT应用于知识提取,并命名为CAT-KD。在具有高可解释性的同时,CAT-KD在多个基准测试中实现了最先进的性能。
2.具体方法
2.1重新审视CNN的结构
令为最后一个卷积层生成的特征图,其中C,W,H为特征的通道数,宽度和高度。令为在(x,y)位置上的第j个通道中F的激活,GAP为全局平均池化,则CNN模型的logits计算可以写为:
其中指的是第i个类别的logits,为的与类别i对应的全连接层权重。则我们可以通过以下方式获得对应于类别i的CAM:
联立上述两个公式, 可以被写为:
从而可以通过计算CAM的平均值,计算logits,基于此我们将CNN结构进行转换,如下图所示。
2.2类别注意力转移(CAT)
CAT的目的是检验一个模型是否可以通过只传递CAM来获得识别输入的类别区分区域的能力。
对于给定的输入,令为转换后产生的CAM结构,其中K代表类别数,W,H为宽度和高度。表示A的第i个通道,也是类别i的CAM,此外,我们使用了平均池化(2*2)降低CAM的分辨率,用以提高CAT的性能。 那么CAT的损失函数可以定义为:
结构图如下图所示:
2.3CAT-KD
将CAT用于知识蒸馏,命名为CAT-KD,则CAT-KD的损失函数为:
我们提出的方法是通过提高学生模型识别类判别区域的能力提高其性能,具有可解释性。
[1]Paying more attention to attention: improving the performance of convolutional neural networks via attention transfer.
[2]Distilling knowledge via knowledge review.
[3]Decoupled knowledge distillation.