知识蒸馏论文详解之:PAYING MORE ATTENTION TO ATTENTION

知识蒸馏:提取复杂模型有用的先验知识,并与简单模型特征结合算出他们的距离,以此来优化简单模型的参数,让简单模型学习复杂模型的dark knowledge,从而帮助简单模型提高性能。

论文地址:https://arxiv.org/abs/1612.03928
github地址:https://github.com/szagoruyko/attention-transfer

PAYING MORE ATTENTION TO ATTENTION: IMPROVING THE PERFORMANCE OF CONVOLUTIONAL NEURAL NETWORKS VIA ATTENTION TRANSFER

这篇论文主要通过提取复杂模型生成的注意力图来指导简单模型,使简单模型生成的注意力图与复杂模型相似。这样简单模型不仅可以学到特征信息,还能够了解如何提炼特征信息。使得简单模型生成的特征更加灵活,不局限于复杂模型。
在这里插入图片描述
图a是输入,b是相应的空间注意力图,它可以表现出网络为了分类所给图片所需要注意的地方。所谓空间注意力图,其实就是将特征图 C × H × W C \times H \times W C×H×W通过映射变换成特征 H × W H\times W H×W。作者将每层channel平方相加获得特征图对应的注意力图。

在这里插入图片描述
上图是人脸识别任务中,对不同维度的特征图进行变换求得的注意力图,可以发现高维注意力图会对整个脸作出反应。
在这里插入图片描述
我们发现精度越高的网络对应的注意力图往往可以抓住图片分类的特征。同时选择合适的变换方程对注意力图的获取也至关重要。
在这里插入图片描述
论文中,作者loss分为两部分,第一部分是分类loss就是简单的交叉熵损失函数来实现分类,后一部分是衡量复杂模型于简单模型注意力图差异的距离损失函数。作者着重强调对注意力图进行归一化的重要性,在学生网络的训练中起很大作用。

def at(x):
    y = F.normalize(x.pow(2).mean(1).view(x.size(0), -1))
    return y

def at_loss(x, y):
    return (at(x) - at(y)).pow(2).mean()

作者提供的代码,实现了对注意力图归一化,对差值平方再求平均。
在这里插入图片描述
从cifar-10结果可以看出,作者提出的Activation-based attention transfer可以有效的学习复杂模型提取特征的能力,从而提高简单模型的分类精度。

  • 4
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值