本文介绍一种知识蒸馏的方法(Similarity-Preserving Knowledge Distillation)。作者针对分类任务,通过“保留相似性”实现更好的蒸馏。
1. 主要思想
作者的构思主要基于一个核心前提:如果两个输入在教师网络中有着高度相似的激活,那么引导学生网络对该输入同样产生高的相似激活(反之亦然)。
上图指示了CIFAR-10中10000张图片在教师网络最后一个卷积层的激活值的均值。这里分成了十类,每一类对应相邻的1000张图片。可见,相邻的1000张图片的激活情况是类似的,而不同类别之间有明显差异。
下图中展示了对于CIFAR-10测试集上的数个batch
的可视化结果。
图中:
- 每一列表示一个单独的
batch
,两个网络都是一致的。 - 每个
batch
的图像中,对于样本的顺序已经通过其真值类别分组。图像表示,相同类别的图像有着更大的相似性,不同类别的图像相似性较小。 - 上下对比也可以看出来,对于复杂模型(下面),对角线上的值更突出,表示模型效果更好。
2. 网络结构
这里是对大小为b
的batch
中。所有图片一起编码,进而得到一个bxb
的相似性矩阵。
3. 损失函数
保留相似性知识蒸馏损失(similarity-preserving knowledge distillation loss):
-
计算学生和老师特征图的相似度矩阵
其中,Qs: (batchsize, h*w)
,Gs: (batchsize, batchsize)
为对称矩阵,表示学生特征图的相似度矩阵。
对Gs
进行归一化后,G_s(i,j)
表示学生特征图第i
和第j
批数据的相似程度。
GT
同理表示老师特征图的相似度矩阵。 -
计算相似度矩阵之差的均值
如上式,先对相似度矩阵作差,然后计算对应位置元素F范数的和,除以元素总数。
代码实现:
class Similarity(nn.Module):
"""Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
def __init__(self):
super(Similarity, self).__init__()
def forward(self, g_s, g_t):
'''对于老师和学生网络输出的每一个元素计算相似性损失'''
return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
def similarity_loss(self, f_s, f_t):
'''损失函数'''
# bsz: batch size
# f_s: [batch_size, h, w]
bsz = f_s.shape[0]
# f_s: [batch_size, h, w] -> [batch_size, hxw]
# f_t: [batch_size, h, w] -> [batch_size, hxw]
f_s = f_s.view(bsz, -1)
f_t = f_t.view(bsz, -1)
# G_s: [batch_size, batch_size]为对称矩阵
G_s = torch.mm(f_s, torch.t(f_s))
G_t = torch.mm(f_t, torch.t(f_t))
# 归一化后,G_s(i,j)表示students features第i和第j批数据的相似程度
G_s = torch.nn.functional.normalize(G_s)
G_t = torch.nn.functional.normalize(G_t)
# 相似度矩阵的差
G_diff = G_t - G_s
# 计算相似度差值矩阵G_diff的元素均值
loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
return loss
4. 训练
# 损失函数
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = Similarity()
for idx, data in enumerate(train_loader):
# ===================forward=====================
loss_cls = criterion_cls(logit_s, target)
loss_div = criterion_div(logit_s, logit_t)
g_s = feat_s[1:-1]
g_t = feat_t[1:-1]
loss_group = criterion_kd(g_s, g_t)
loss_kd = sum(loss_group)
loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
# ===================backward=====================
optimizer.zero_grad()
loss.backward()
optimizer.step()
# ===================meters=====================
batch_time.update(time.time() - end)
end = time.time()
其中的feat_s
是中间特征层,例如对于resnet8
if preact:
return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
return [f0, f1, f2, f3, f4], x
论文理解部分参考文献:Similarity-Preserving Knowledge Distillation