蒸馏论文三(Similarity-Preserving)

本文介绍一种知识蒸馏的方法(Similarity-Preserving Knowledge Distillation)。作者针对分类任务,通过“保留相似性”实现更好的蒸馏。

1. 主要思想

作者的构思主要基于一个核心前提:如果两个输入在教师网络中有着高度相似的激活,那么引导学生网络对该输入同样产生高的相似激活(反之亦然)。
在这里插入图片描述
上图指示了CIFAR-10中10000张图片在教师网络最后一个卷积层的激活值的均值。这里分成了十类,每一类对应相邻的1000张图片。可见,相邻的1000张图片的激活情况是类似的,而不同类别之间有明显差异。

下图中展示了对于CIFAR-10测试集上的数个batch的可视化结果。

在这里插入图片描述
图中:

  1. 每一列表示一个单独的batch,两个网络都是一致的。
  2. 每个batch的图像中,对于样本的顺序已经通过其真值类别分组。图像表示,相同类别的图像有着更大的相似性,不同类别的图像相似性较小。
  3. 上下对比也可以看出来,对于复杂模型(下面),对角线上的值更突出,表示模型效果更好。

2. 网络结构

在这里插入图片描述
这里是对大小为bbatch中。所有图片一起编码,进而得到一个bxb的相似性矩阵。

3. 损失函数

保留相似性知识蒸馏损失(similarity-preserving knowledge distillation loss):

  1. 计算学生和老师特征图的相似度矩阵
    在这里插入图片描述
    其中,Qs: (batchsize, h*w)Gs: (batchsize, batchsize)为对称矩阵,表示学生特征图的相似度矩阵。
    Gs进行归一化后,G_s(i,j)表示学生特征图第i和第j批数据的相似程度。
    在这里插入图片描述
    GT同理表示老师特征图的相似度矩阵。

  2. 计算相似度矩阵之差的均值
    在这里插入图片描述
    如上式,先对相似度矩阵作差,然后计算对应位置元素F范数的和,除以元素总数。

代码实现:

class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()

    def forward(self, g_s, g_t):
        '''对于老师和学生网络输出的每一个元素计算相似性损失'''
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def similarity_loss(self, f_s, f_t):
        '''损失函数'''
        # bsz: batch size
        # f_s: [batch_size, h, w]
        bsz = f_s.shape[0]

        # f_s: [batch_size, h, w] -> [batch_size, hxw]
        # f_t: [batch_size, h, w] -> [batch_size, hxw]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)

        # G_s: [batch_size, batch_size]为对称矩阵
        G_s = torch.mm(f_s, torch.t(f_s))
        G_t = torch.mm(f_t, torch.t(f_t))

        # 归一化后,G_s(i,j)表示students features第i和第j批数据的相似程度
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.nn.functional.normalize(G_t)

        # 相似度矩阵的差
        G_diff = G_t - G_s

        # 计算相似度差值矩阵G_diff的元素均值
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)

        return loss

4. 训练

# 损失函数
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = Similarity()

for idx, data in enumerate(train_loader):
	# ===================forward=====================
	loss_cls = criterion_cls(logit_s, target)
	loss_div = criterion_div(logit_s, logit_t)
	        
	g_s = feat_s[1:-1]
	g_t = feat_t[1:-1]
	loss_group = criterion_kd(g_s, g_t)
	loss_kd = sum(loss_group)
	
	loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
    # ===================backward=====================
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # ===================meters=====================
    batch_time.update(time.time() - end)
    end = time.time()

其中的feat_s是中间特征层,例如对于resnet8

if preact:
    return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
    return [f0, f1, f2, f3, f4], x

论文理解部分参考文献:Similarity-Preserving Knowledge Distillation

加速基于相似性模型匹配的On-The-Fly相似性保持哈希 摘要: 在软件工程中,模型匹配是一项核心任务,广泛应用于模型驱动工程、软件重构、需求管理、代码检查等领域。由于模型通常包含大量的元素和复杂的结构,模型匹配问题变得越来越具有挑战性。相似性匹配是一种流行的模型匹配方法,它通过计算语义相似度来匹配模型元素。然而,由于相似性匹配算法的计算复杂度很高,导致它们的效率低下。 为了提高相似性匹配的效率,我们提出了一种基于On-The-Fly相似性保持哈希的加速方法。该方法利用哈希表将元素映射到桶中,并在桶中使用相似性保持哈希函数计算相似性,从而避免了在匹配过程中进行昂贵的相似性计算。此外,我们还提出了一种基于哈希冲突的剪枝策略,以进一步提高匹配效率。 我们在多个数据集上进行了实验,并与现有相似性匹配算法进行了比较。实验结果表明,我们的方法可以显著提高匹配效率,同时保持高精度。 关键词:模型匹配,相似性匹配,哈希,剪枝 Abstract: In software engineering, model matching is a core task widely applied in model-driven engineering, software refactoring, requirement management, code inspection, etc. Due to the fact that models usually contain a large number of elements and complex structures, model matching problems become increasingly challenging. Similarity-based matching is a popular model matching approach that matches model elements by computing semantic similarities. However, due to the high computational complexity of similarity-based matching algorithms, they suffer from poor efficiency. To improve the efficiency of similarity-based matching, we propose an acceleration method based on On-The-Fly similarity preserving hashing. This method uses a hash table to map elements to buckets and employs similarity preserving hash functions to compute similarities within buckets, thus avoiding expensive similarity computations during the matching process. In addition, we propose a hash conflict-based pruning strategy to further improve the matching efficiency. We conduct experiments on multiple datasets and compare our method with existing similarity-based matching algorithms. Experimental results show that our method can significantly improve the matching efficiency while maintaining high accuracy. Keywords: Model matching, similarity-based matching, hashing, pruning.
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一只蓝鲸鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值