本系列文章介绍一些知识蒸馏领域的经典文章。
知识蒸馏:提取复杂模型有用的先验知识,并与简单模型特征结合算出他们的距离,以此来优化简单模型的参数,让简单模型学习复杂模型,从而帮助简单模型提高性能。
1. Attention Transfer原理
论文Paying more attention to attention主要通过提取复杂模型生成的注意力图来指导简单模型,使简单模型生成的注意力图与复杂模型相似。这样,简单模型不仅可以学到特征信息,还能够了解如何提炼特征信息。使得简单模型生成的特征更加灵活,不局限于复杂模型。
其中,图a是输入,b是相应的空间注意力图,它可以表现出网络为了分类所给图片所需要注意的地方。所谓空间注意力图,其实就是将特征图[C, H , W]
通过映射变换成特征[H, W]
。作者将每层通道平方相加获得特征图对应的注意力图。
上图是人脸识别任务中,对不同维度的特征图进行变换求得的注意力图,可以发现高维注意力图会对整个脸作出反应。
2. 损失函数
论文中,作者将损失分为两部分:
第一部分是分类损失是简单的交叉熵损失函数,作用是实现分类。
第二部分是衡量复杂模型于简单模型注意力图差异的函数。首先注意力图进行归一化,即除以自身的模值,然后计算两个注意力图的p范数。
第三部分也是衡量复杂模型于简单模型注意力图差异的函数,在实际代码实现中使用KL散度实现。KL散度是用来衡量两个概率分布之间的相似性的函数。不了解KL散度的见这里。
参考代码,以下是使用attention transfer
技巧的关键部分代码。
class Attention(nn.Module):
"""Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
via Attention Transfer
code: https://github.com/szagoruyko/attention-transfer"""
def __init__(self, p=2):
super(Attention, self).__init__()
self.p = p
def forward(self, g_s, g_t):
'''对于老师和学生网络输出的每一个元素计算损失'''
return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
def at_loss(self, f_s, f_t):
'''损失函数'''
s_H, t_H = f_s.shape[2], f_t.shape[2]
# 通过pooling将teacher features和student features调整为统一大小
# pooling成两者中较小的size
if s_H > t_H:
f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
elif s_H < t_H:
f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
else:
pass
# 返回归一化后teacher features和student features的欧氏距离
return (self.at(f_s) - self.at(f_t)).pow(2).mean()
def at(self, f):
'''归一化'''
return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
3. 训练
# 损失函数
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = Attention()
for idx, data in enumerate(train_loader):
# ===================forward=====================
loss_cls = criterion_cls(logit_s, target)
loss_div = criterion_div(logit_s, logit_t)
g_s = feat_s[1:-1]
g_t = feat_t[1:-1]
loss_group = criterion_kd(g_s, g_t)
loss_kd = sum(loss_group)
loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
# ===================backward=====================
optimizer.zero_grad()
loss.backward()
optimizer.step()
# ===================meters=====================
batch_time.update(time.time() - end)
end = time.time()
其中的feat_s
是中间特征层,例如对于resnet8
if preact:
return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
return [f0, f1, f2, f3, f4], x
论文理解部分参考文献:
知识蒸馏论文详解之:PAYING MORE ATTENTION TO ATTENTION