知识蒸馏学习笔记-manifold-distillation

论文地址: https://arxiv.org/pdf/2107.01378.pdf

开源代码: https://github.com/Hao840/manifold-distillation?tab=readme-ov-file

这是一种适用于视觉transformer的蒸馏方法,以提高学生网络的性能。该方法利用流形学习构建了当前批次内所有样本的patch之间的关系,并将这些关系传递给学生网络,以获取教师网络的中间特征。与此同时,还结合了对网络输出logits的蒸馏,通过辅助学生网络的训练来提高其性能。

为了方便实际训练,作并没有直接构建批次内所有patch间的关系,而是按照特定的模式对这些关系进行采样,以减少内存消耗和计算量。同时,为了减少采样误差,还引入了随机采样项。

与其他适用于视觉transformer的蒸馏方法相比,该方法更充分利用了教师网络所携带的信息,不要求教师网络和学生网络的特征维度对齐,因此具有较高的灵活性。这使得在选择教师网络时更加自由。实验证明,该蒸馏方法显著提升了学生transformer网络的性能,例如,在ImageNet数据集和DeiT-Tiny学生网络中获得了0.6%的Top1准确率提升,相较于DeiT中的蒸馏方法。

小辉问:为什么说所有样本的patch之间的关系就能代表着教师网络中间的特征?

小G答:在视觉transformer中,图像被分割成许多patch,并且每个patch都经过transformer网络的处理,生成相应的特征表示。这些patch之间的关系在特定任务中包含了重要的信息,因为它们反映了图像中不同区域之间的语义关联和上下文信息。通过学习这些patch之间的关系,可以捕捉到更高层次的语义和结构信息,从而代表了教师网络中间的特征。这种方法的基本思想是将patch之间的关系建模为一个流形(manifold),即一种在高维空间中的曲面结构。这个流形反映了patch之间的相似性和差异性,通过传递这个流形结构给学生网络,使得学生网络能够更好地学习到教师网络中的特征表示。总体而言,patch之间的关系被认为是包含了图像语义信息的一种表达方式,因此通过学习这种关系,可以有效地捕捉到图像中的特征和结构,提高了蒸馏方法的性能.

小辉问:网络输出logits的蒸馏,通过辅助学生网络的训练来提高其性能。这句话不是很理解,能举个例子详细讲讲吗?

小G答:在深度学习中,蒸馏(Knowledge Distillation)是一种通过教导一个较小的模型(学生网络)去模拟一个较大、性能更好的模型(教师网络)的方法。具体而言,对于分类任务,通常是通过教导学生网络输出与教师网络相似的概率分布,即logits(逻辑回归的输入)。举个例子,假设有一个图像分类任务,教师网络(大模型)输出了对于某张图像的概率分布,例如,[0.1, 0.8, 0.1] 表示模型认为该图像属于第二个类别的概率最高。这就是教师网络的logits。现在,我们希望通过训练一个较小的学生网络,使得学生网络也能输出类似的概率分布。通过蒸馏,我们会调整学生网络的参数,使得它的输出概率分布与教师网络的概率分布尽可能相似。这就是“网络输出logits的蒸馏”。通过这种方式,学生网络能够从教师网络中获得更多的知识,包括类别之间的关系和任务中的重要特征。这种蒸馏的过程有助于提高学生网络的性能,尤其是在资源受限的情况下,通过借助大模型的知识来训练小模型,实现性能的平衡。

上图是一个batch同一个图片和不同图片的每一个split后的每个 patch之间的流行空间

视觉transformer模型层

模型层包含两个部分: 多头自注意力(MSA)和多层感知机(MLP)。 输入是拆分的 patch ,每个 patch 经过处理后形成嵌入表示。

模型层的操作分为两步:首先通过MSA处理,然后通过MLP处理。 每一步的结果都会与之前的嵌入相加。

小辉问:为什么后面还要加个Xe?

小G答:在深度学习中,每一步的结果与之前的嵌入相加的设计被称为残差连接,它具有以下目的和原理:

1.残差连接与梯度消失问题: 通过残差连接,模型可以更轻松地学习残差部分,从而缓解梯度消失问题。这是由何凯明等人提出的一种深度神经网络结构,通过引入跳跃连接,使得模型更容易训练深层网络。

2.信息融合: 将新的信息与之前的嵌入相加,有助于逐层融合信息,使模型能够学习多层次的特征表示,更好地捕捉输入数据的复杂性和抽象性。

3.稳定性: 相加的操作使得模型更加稳定。即使新的嵌入信息可能引入一些变化,但由于残差连接,模型不会完全丧失之前学到的信息。

4.残差连接的原理: 残差连接的核心思想是引入一个跳跃连接,直接将输入信号绕过一个或多个层,将其加到输出上。这有效地缓解了深层网络的训练问题,提高了网络的性能和训练效率。

5.梯度传播和优化: 残差连接通过提供一条捷径,使得梯度可以直接通过跳跃连接传递,从而减缓梯度的变化。这使得网络在反向传播时更容易学习恒等映射,防止梯度消失和梯度爆炸问题。

总体而言,残差连接的设计有助于提高模型的训练稳定性,减缓梯度消失问题,并促使模型学习更丰富的层次化特征表示。这对于训练深层网络,特别是在图像处理和计算机视觉任务中,具有重要的作用。

知识蒸馏(KD)损失函数

KD是一种用于模型压缩的方法,通过将大型教师模型的预测作为学生模型的学习目标。 损失函数包含两部分:交叉熵损失(HCE)和Kullback-Leibler散度损失(HKL)。 交叉熵损失用于确保学生的预测与标签一致。 KL散度损失用于确保学生的预测与教师的预测相似。

Manifold Relation Map计算

对于批量大小为B的样本,将学生层和教师层的特征分别表示为:

其中,DT和DS是嵌入维度,首先对最后一个维度进行规范化,然后计算manifold结构,即manifold relation map,公式如下:

Manifold Distillation Loss

计算学生层和教师层的manifold relation map之间的差异,使用以下损失函数:

计算资源优化

尽管manifold relation map的计算非常耗费资源,但为了简化计算,将manifold relation map分解为三个部分:一个intra-image relation map(图像内关系映射),一个inter-image relation map(图像间关系映射),和一个随机抽样的relation map。

Intra-image Relation Map: 衡量同一图像内部patch之间的关系。

Inter-image Relation Map: 衡量不同图像之间的patch之间的关系。

随机抽样的Relation Map: 随机选择一小部分patch进行关系映射。

通过这种分解,可以显著减少计算复杂性和内存需求,使fine-grained manifold distillation方法更具可扩展性。

(a) Intra-image关系映射:表示在同一图像内部计算的关系映射。每个颜色填充的组中计算了一个关系映射。

(b) Inter-image关系映射:表示在不同图像之间计算的关系映射。每个颜色填充的组中计算了一个关系映射。

(c) 随机抽样的Relation Map:表示通过随机抽样计算的关系映射。每个颜色填充的组中计算了一个关系映射。

计算同一图像内部和不同图像之间patch-level manifold distillation loss

不同图像之间的损失(Linter)

随机抽样的patch-level manifold distillation loss

整体损失函数

将KD损失与分解的manifold distillation损失组合,得到整体损失函数

总结

该方法将visison transformer为一个特征投影器,逐层将图像patch嵌入到一系列平滑的manifold空间中。在每对手动选择的教师-学生层之间,目标是教导学生层输出具有与教师层相同的patch-level manifold结构的特征。

  • 31
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值