本文介绍了一种知识蒸馏的方法——神经元选择性迁移(Neuron Selectivity Transfer, NST)。
1. 主要思想
如图所示为神经元选择性迁移的架构。学生网络不仅利用真正的标签训练,而且还模仿了教师网络中间层的激活分布。图中的每个点或三角形表示其对应的滤波器的激活图。
2. 损失函数
最大平均差异(Maximum Mean Discrepancy, MMD)用作损失函数来衡量教师和学生特征之间的差异。MMD 的想法就是求两个随机变量经特征映射 $\phi(\cdot)$ 映到高维空间后均值之间的距离:
$$\mathcal{L}_{\mathrm{MMD}^2}(\mathcal{X},\mathcal{Y})=\Big\|\frac{1}{N}\sum_{i=1}^{N}\phi(x^i)-\frac{1}{M}\sum_{j=1}^{M}\phi(y^j)\Big\|_2^2$$
可以应用核技巧(kernel trick)将上式展开,转化成如下形式:
$$\mathcal{L}=\frac{1}{N^2}\sum_{i=1}^{N}\sum_{i'=1}^{N}k(x^i,x^{i'})+\frac{1}{M^2}\sum_{j=1}^{M}\sum_{j'=1}^{M}k(y^j,y^{j'})-\frac{2}{NM}\sum_{i=1}^{N}\sum_{j=1}^{M}k(x^i,y^j)$$
在应用中,上述 $x$ 和 $y$ 为学生网络和教师网络归一化后的特征图,因此:
这里的 $k(\cdot,\cdot)$ 为核函数,有以下几种可采用的形式:线性核 $k(x,y)=x^\top y$、多项式核 $k(x,y)=(x^\top y+c)^d$、高斯核 $k(x,y)=\exp\!\big(-\tfrac{\|x-y\|_2^2}{2\sigma^2}\big)$。
在应用中,作者采用了 $d=2$、$c=0$ 的多项式核函数。
class NSTLoss(nn.Module):
"""like what you like: knowledge distill via neuron selectivity transfer"""
def __init__(self):
super(NSTLoss, self).__init__()
pass
def forward(self, g_s, g_t):
return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
def nst_loss(self, f_s, f_t):
s_H, t_H = f_s.shape[2], f_t.shape[2]
# 通过pooling将teacher features和student features调整为统一大小
# pooling成两者中较小的size
if s_H > t_H:
f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
elif s_H < t_H:
f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
else:
pass
f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
f_s = F.normalize(f_s, dim=2)
f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
f_t = F.normalize(f_t, dim=2)
# set full_loss as False to avoid unnecessary computation
full_loss = True
if full_loss:
return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
- 2 * self.poly_kernel(f_s, f_t).mean())
else:
return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()
def poly_kernel(self, a, b):
'''d=2, c=0的多项式核函数'''
return (a@b.transpose(-1,-2)).transpose(-1,-2).pow(2)
3. 训练
# Loss functions: cross-entropy on labels, KL divergence on logits (vanilla
# KD), and the NST feature loss defined above.
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = NSTLoss()
for idx, data in enumerate(train_loader):
    # ===================forward=====================
    # NOTE(review): this is an excerpt — logit_s/logit_t, feat_s/feat_t,
    # target, and `end` are produced by code elided from the snippet.
    loss_cls = criterion_cls(logit_s, target)
    loss_div = criterion_div(logit_s, logit_t)
    # Use only the intermediate feature maps: drop the stem output (first)
    # and the final feature (last) from both networks.
    g_s = feat_s[1:-1]
    g_t = feat_t[1:-1]
    # One NST loss per (student, teacher) layer pair; summed into a scalar.
    loss_group = criterion_kd(g_s, g_t)
    loss_kd = sum(loss_group)
    # Weighted total: gamma * CE + alpha * KD-divergence + beta * NST.
    loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
    # ===================backward=====================
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # ===================meters=====================
    batch_time.update(time.time() - end)
    end = time.time()
其中的 `feat_s` 是学生网络的中间特征层列表,例如对于 resnet8:
# Excerpt from the backbone's forward(): return the list of intermediate
# feature maps plus the final output x.
if preact:
    # f1_pre..f3_pre are the pre-activation (before ReLU) feature maps;
    # presumably requested when the distillation method needs pre-activations.
    return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
    return [f0, f1, f2, f3, f4], x