蒸馏论文三(Similarity-Preserving)

本文介绍一种知识蒸馏的方法(Similarity-Preserving Knowledge Distillation)。作者针对分类任务,通过“保留相似性”实现更好的蒸馏。

1. 主要思想

作者的构思主要基于一个核心前提:如果两个输入在教师网络中有着高度相似的激活,那么引导学生网络对该输入同样产生高的相似激活(反之亦然)。
在这里插入图片描述
上图指示了CIFAR-10中10000张图片在教师网络最后一个卷积层的激活值的均值。这里分成了十类,每一类对应相邻的1000张图片。可见,相邻的1000张图片的激活情况是类似的,而不同类别之间有明显差异。

下图中展示了对于CIFAR-10测试集上的数个batch的可视化结果。

在这里插入图片描述
图中:

  1. 每一列表示一个单独的batch,两个网络都是一致的。
  2. 每个batch的图像中,对于样本的顺序已经通过其真值类别分组。图像表示,相同类别的图像有着更大的相似性,不同类别的图像相似性较小。
  3. 上下对比也可以看出来,对于复杂模型(下面),对角线上的值更突出,表示模型效果更好。

2. 网络结构

在这里插入图片描述
这里是对大小为bbatch中。所有图片一起编码,进而得到一个bxb的相似性矩阵。

3. 损失函数

保留相似性知识蒸馏损失(similarity-preserving knowledge distillation loss):

  1. 计算学生和老师特征图的相似度矩阵
    在这里插入图片描述
    其中,Qs: (batchsize, h*w)Gs: (batchsize, batchsize)为对称矩阵,表示学生特征图的相似度矩阵。
    Gs进行归一化后,G_s(i,j)表示学生特征图第i和第j批数据的相似程度。
    在这里插入图片描述
    GT同理表示老师特征图的相似度矩阵。

  2. 计算相似度矩阵之差的均值
    在这里插入图片描述
    如上式,先对相似度矩阵作差,然后计算对应位置元素F范数的和,除以元素总数。

代码实现:

class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()

    def forward(self, g_s, g_t):
        '''对于老师和学生网络输出的每一个元素计算相似性损失'''
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def similarity_loss(self, f_s, f_t):
        '''损失函数'''
        # bsz: batch size
        # f_s: [batch_size, h, w]
        bsz = f_s.shape[0]

        # f_s: [batch_size, h, w] -> [batch_size, hxw]
        # f_t: [batch_size, h, w] -> [batch_size, hxw]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)

        # G_s: [batch_size, batch_size]为对称矩阵
        G_s = torch.mm(f_s, torch.t(f_s))
        G_t = torch.mm(f_t, torch.t(f_t))

        # 归一化后,G_s(i,j)表示students features第i和第j批数据的相似程度
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.nn.functional.normalize(G_t)

        # 相似度矩阵的差
        G_diff = G_t - G_s

        # 计算相似度差值矩阵G_diff的元素均值
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)

        return loss

4. 训练

# 损失函数
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = Similarity()

for idx, data in enumerate(train_loader):
	# ===================forward=====================
	loss_cls = criterion_cls(logit_s, target)
	loss_div = criterion_div(logit_s, logit_t)
	        
	g_s = feat_s[1:-1]
	g_t = feat_t[1:-1]
	loss_group = criterion_kd(g_s, g_t)
	loss_kd = sum(loss_group)
	
	loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
    # ===================backward=====================
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # ===================meters=====================
    batch_time.update(time.time() - end)
    end = time.time()

其中的feat_s是中间特征层,例如对于resnet8

if preact:
    return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
    return [f0, f1, f2, f3, f4], x

论文理解部分参考文献:Similarity-Preserving Knowledge Distillation

  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值