蒸馏论文二(Attention Transfer)

本系列文章介绍一些知识蒸馏领域的经典文章。

知识蒸馏:提取复杂模型有用的先验知识,并与简单模型特征结合算出他们的距离,以此来优化简单模型的参数,让简单模型学习复杂模型,从而帮助简单模型提高性能。

1. Attention Transfer原理

论文Paying more attention to attention主要通过提取复杂模型生成的注意力图来指导简单模型,使简单模型生成的注意力图与复杂模型相似。这样,简单模型不仅可以学到特征信息,还能够了解如何提炼特征信息。使得简单模型生成的特征更加灵活,不局限于复杂模型。

在这里插入图片描述
其中,图a是输入,b是相应的空间注意力图,它可以表现出网络为了分类所给图片所需要注意的地方。所谓空间注意力图,其实就是将特征图[C, H , W]通过映射变换成特征[H, W]。作者将每层通道平方相加获得特征图对应的注意力图。

在这里插入图片描述
上图是人脸识别任务中,对不同维度的特征图进行变换求得的注意力图,可以发现高维注意力图会对整个脸作出反应。

在这里插入图片描述

2. 损失函数

论文中,作者将损失分为两部分:在这里插入图片描述
第一部分是分类损失是简单的交叉熵损失函数,作用是实现分类。

第二部分是衡量复杂模型于简单模型注意力图差异的函数。首先注意力图进行归一化,即除以自身的模值,然后计算两个注意力图的p范数。

第三部分也是衡量复杂模型于简单模型注意力图差异的函数,在实际代码实现中使用KL散度实现。KL散度是用来衡量两个概率分布之间的相似性的函数。不了解KL散度的见这里
在这里插入图片描述
参考代码,以下是使用attention transfer技巧的关键部分代码。

class Attention(nn.Module):
    """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
    via Attention Transfer
    code: https://github.com/szagoruyko/attention-transfer"""
    def __init__(self, p=2):
        super(Attention, self).__init__()
        self.p = p

    def forward(self, g_s, g_t):
        '''对于老师和学生网络输出的每一个元素计算损失'''
        return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def at_loss(self, f_s, f_t):
        '''损失函数'''
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        # 通过pooling将teacher features和student features调整为统一大小
        # pooling成两者中较小的size
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass

        # 返回归一化后teacher features和student features的欧氏距离
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()

    def at(self, f):
        '''归一化'''
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

3. 训练

# 损失函数
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = Attention()

for idx, data in enumerate(train_loader):
	# ===================forward=====================
	loss_cls = criterion_cls(logit_s, target)
	loss_div = criterion_div(logit_s, logit_t)
	        
	g_s = feat_s[1:-1]
	g_t = feat_t[1:-1]
	loss_group = criterion_kd(g_s, g_t)
	loss_kd = sum(loss_group)
	
	loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
    # ===================backward=====================
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # ===================meters=====================
    batch_time.update(time.time() - end)
    end = time.time()

其中的feat_s是中间特征层,例如对于resnet8

if preact:
    return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
    return [f0, f1, f2, f3, f4], x

论文理解部分参考文献:
知识蒸馏论文详解之:PAYING MORE ATTENTION TO ATTENTION

  • 3
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值