DeiT : Training data-efficient image transformers & distillation through attention

1. 引言

        受自然语言处理中基于注意的模型成功的激励,人们对利用convnets中注意机制的架构越来越感兴趣。最近,一些研究人员提出了混合架构,将Transformer成分移植到convnets中来解决视觉任务。

        Vision Transformer在数据量不足的情况下不能很好地泛化。这些模型的训练需要大量的计算资源。引入了一种基于令牌的策略,DeiT有利地取代了Transformer的常规蒸馏。本工作做出了以下贡献:我们的神经网络不包含卷积层,可以在没有外部数据的情况下,在ImageNet上取得与最先进的结果相竞争的结果。它们是在一个有4个gpu的节点上学习的,需要3天时间。新模型DeiT-S和DeiT-Ti参数更少,可以看作是ResNet-50和ResNet-18的对应。引入了一个基于蒸馏令牌的新蒸馏过程,该蒸馏令牌的作用与类别令牌相同,不同之处是它旨在重现教师估计的标签。两个令牌通过注意在Transformer中相互作用。这种特定于Transformer的策略比普通蒸馏的性能要好得多。

2. 相关工作

2.1 知识蒸馏

        知识蒸馏是一种训练范式,其中学生模型利用来自强大教师网络的“软”标签进行学习。这些软标签是教师网络的softmax函数的输出向量,而不仅仅是得分最高的那个“硬”标签。这种训练方式能提升学生模型的性能,也可以视为将教师模型“压缩”到更小的学生模型中的一种形式。一方面,教师的软标签与标签平滑有类似的效果,即减少模型对硬标签的过度自信,提高泛化能力。另一方面,Wei等人的研究表明,教师的监督考虑了数据增强的影响,这有时会导致真实标签与图像之间的不一致。例如,在一张标记为“猫”的图片中,虽然整体是广阔的风景,但角落里有一只小猫。如果数据增强裁剪掉了小猫,那么实际上这张图片的标签就隐含地改变了。

        知识蒸馏能够以软性的方式将学生模型中的归纳偏置从教师模型中迁移过来,而这些偏置在教师模型中可能是以硬性方式嵌入的。在我们的研究中,探索了使用卷积神经网络(ConvNet)或Transformer作为教师模型来指导Transformer学生模型的蒸馏过程,旨在将卷积偏置以软性方式引入Transformer中。

3. Vision Transformer概述

3.1 多头自注意力层

        查询向量q\in R^d与一组 k个键向量(组合成一个矩阵K\in R^{k\times d})使用内积进行匹配。然后用softmax函数对这些内积进行缩放和规范化以获得 k 个权重。注意力的输出是一组 k 个值向量(打包成V\in R^{k\times d})的加权和。对于一个包含N个查询向量的序列(封装到Q\in R^{N\times d}中),它产生一个输出矩阵(大小为N × d): 

 \mathrm{Attention}(Q,K,V)=\mathrm{Softmax}(QK^{\top}/\sqrt{d})V,\quad(1)

其中Softmax函数应用于输入矩阵的每一行。\sqrt{d} 提供了适当的规格化。为了防止当d较大时,点积结果会非常大,导致softmax函数的梯度变得非常小,从而引发梯度消失问题。缩放因子\sqrt{d} 有助于保持梯度的稳定性。

        在Transformer模型中,缩放点积注意力机制是一种计算序列中元素之间相互依赖性的方法。具体来说,对于输入矩阵XN\times D维的输入向量),首先通过三个线性变换(即矩阵乘法)分别生成查询(Query, Q)、键(Key, K)和值(Value, V)矩阵。这些变换由权重矩阵W_Q,W_K.W_V定义,这些权重矩阵是模型需要学习的参数。

        多头自注意力机制是Transformer中用于增强模型表示能力的一种方法。它通过并行地运行多个自注意力层(即“头”)来实现,每个头都能独立地关注输入序列的不同部分。每个头都会执行上述的缩放点积注意力机制,但使用不同的线性变换(即不同的权重矩阵W_Q,W_K.W_V)来生成查询、键和值矩阵。这允许每个头捕捉到输入序列中不同的依赖关系。每个头都会输出一个N×d的序列(其中N是序列长度,d是每个头的输出维度)。然后,这些序列被重新排列成一个N×dh的序列(h是头的数量),最后通过一个线性层将这个序列映射回N×D的维度,以便与模型的其余部分兼容。

3.2 将Transformer应用于图像

        为了将Transformer模型应用于图像处理,基于ViT模型进行工作,该模型以一种简单而优雅的方式将输入图像视为输入令牌的序列。在ViT中,固定大小的输入RGB图像被分解成一批固定大小为16x16像素的N个块(N=14x14)。每个块通过一个线性层进行投影,保持其整体维度3x16x16=768不变。

        为了构建一个完整的Transformer块,在多头自注意力(MSA)层之上添加了一个前馈网络(FFN)。这个FFN由两个线性层组成,中间通过一个GeLu激活函数分隔。第一个线性层将维度从D扩展到4D,第二个线性层再将其缩减回D。MSA和FFN都通过跳跃连接和层归一化作为残差算子进行操作。

        然而,Transformer块对块嵌入的顺序是不变的,因此忽略了它们的位置信息。为了解决这个问题,通过添加位置嵌入来引入位置信息,这些位置嵌入可以是固定的或可训练的。位置嵌入在第一个Transformer块之前被添加到块令牌中,然后这些令牌被送入Transformer块的堆叠中。

        此外,引入了一个可训练的类令牌(class token),它在第一层之前被附加到patch令牌上。这个类令牌通过Transformer层进行传播,并最终通过一个线性层进行投影以预测类别。这个类令牌是从自然语言处理(NLP)领域继承而来的,与计算机视觉中通常用于预测类别的池化层不同。

        因此,Transformer处理的是一批(N+1)个维度为D的令牌,其中只有类向量被用于预测输出。这种架构迫使自注意力机制在patch令牌和类令牌之间传播信息:在训练时,监督信号仅来自类嵌入,这迫使模型学习如何有效地将全局图像信息编码到类令牌中,以便进行准确的类别预测。

4. 通过注意力进行蒸馏

        软蒸馏最小化教师模型的softmax和学生模型的softmax之间的Kullback-Leibler分歧。设Z_t为教师模型的分数,Z_s为学生模型的分数。我们用 \tau 表示蒸馏的温度,\lambda 表示在真正值 y 上平衡Kullback-Leibler散度损失(KL)和交叉熵(LCE)的系数,ψ表示softmax函数。蒸馏的目的是

\begin{aligned}\mathcal{L}_{\mathrm{global}}&=(1-\lambda)\mathcal{L}_{\mathrm{CE}}(\psi(Z_{\mathrm{s}}),y)\\&+\lambda\tau^{2}\mathrm{KL}(\psi(Z_{\mathrm{s}}/\tau),\psi(Z_{\mathrm{t}}/\tau)).\quad(2)\end{aligned} 

        硬蒸馏把教师模型的预测值作为一个真正的标签,设{y_{\mathrm{t}} = \mathrm{argmax}_{c}Z_{\mathrm{t}}(c)}是教师模型的预测值,与这个硬标签蒸馏相关的目标是 \mathcal{L}_{\mathrm{global}}^{\mathrm{hardDistill}}=\frac12\mathcal{L}_{\mathrm{CE}}(\psi(Z_s),y)+\frac12\mathcal{L}_{\mathrm{CE}}(\psi(Z_s),y_\mathrm{t}).\text{(3)}对于给定的图像,与教师相关的硬标签可能会随着具体的数据增强而变化。

4.1 标签平滑

        在训练深度学习模型时,硬标签(即完全确定的类别标签)可以通过标签平滑技术转换为软标签。其核心思想是将真实标签的概率设置为1-ε,而将剩余的ε概率均匀分配给其他所有类别。在所有实验中,当使用真实标签时,固定ε的值为0.1。需要注意的是,对于由教师模型提供的伪标签(例如在硬蒸馏中),并不进行平滑处理,因为这些伪标签本身已经包含了某种程度的不确定性。

4.2 蒸馏令牌(Distillation Token)

        我们在初始嵌入(包括块嵌入和类别令牌)中引入了一个新的令牌,即蒸馏令牌。蒸馏令牌的使用方式与类别令牌类似:它通过自注意力机制与其他嵌入进行交互,并在网络的最后一层被输出。蒸馏令牌的目标损失由蒸馏损失组件给出。蒸馏嵌入使得我们的模型能够像常规蒸馏那样从教师模型的输出中学习,同时与类别嵌入保持互补关系。

4.3 使用蒸馏的微调

        在更高分辨率的微调阶段,同时使用真实标签和教师模型的预测。教师模型的目标分辨率与正在训练的模型相同,这通常通过Touvron等人(2019)的方法从较低分辨率的教师模型中获得。我们也尝试过仅使用真实标签进行微调,但这会降低教师模型带来的好处,并导致性能下降。因此,结合真实标签和教师预测可以更有效地利用教师模型的知识。

4.4 使用联合分类器进行分类

        在测试时,Transformer模型产生的类别嵌入或蒸馏嵌入都可以与线性分类器相关联,用于推断图像标签。参考方法是这两个独立头部的后期融合,即我们将两个分类器的softmax输出相加来做出预测。

4.5 引入蒸馏令牌的操作

        通过引入一个新的蒸馏令牌(distillation token)来实现。这个蒸馏令牌在自注意力层中与类别令牌(class token)和块令牌(patch tokens,通常用于处理图像中的小块)进行交互。这个蒸馏令牌的使用方式与类别令牌相似,但在网络输出时,它的目标是重现由教师模型预测的(硬)标签,而不是真实的标签。

具体来说,这段话的含义包括以下几个方面:

  1. 蒸馏令牌的引入:为了进行知识蒸馏,模型被扩展以包含一个额外的蒸馏令牌。这个令牌被添加到模型的输入中,与类别令牌和图像块令牌一起,通过自注意力机制进行交互。

  2. 蒸馏令牌的作用:蒸馏令牌在模型中扮演着特殊的角色。尽管它的交互方式与类别令牌相似(都通过自注意力机制与其他令牌交互),但在模型训练时,蒸馏令牌的目标是使网络的输出能够尽可能接近教师模型对同一输入数据的预测(即硬标签),而不是真实的标签。

  3. 反向传播学习:无论是类别令牌还是蒸馏令牌,它们作为输入到Transformer模型中的一部分,都是通过学习过程(特别是反向传播算法)来优化其表示的。这意味着,在训练过程中,网络会根据损失函数(包括与真实标签和教师预测相关的部分)来调整这些令牌的表示,以便更好地完成分类和蒸馏任务。

  4. 蒸馏目标:通过让蒸馏令牌学习教师模型的预测,模型能够捕获教师模型中的知识,并将其用于提高自身的性能。这种知识蒸馏技术通常能够帮助学生模型(即正在训练的模型)在保持或提高准确性的同时,减少对数据集标签的依赖,从而提高模型的泛化能力。

5. 模型架构

DeiT模型架构

输入嵌入层(Patch Embedding)

        输入图像首先被划分为固定大小的非重叠图像块(patches),例如16x16的块。每个图像块被展平为一维向量,然后通过一个线性层(或者称为投影层)映射到一个固定的维度,通常是d_model(例如768维),这些嵌入向量被称为patch嵌入。

位置编码(Positional Encoding)

        因为Transformer没有卷积网络的位置信息捕捉能力,所以需要加上位置编码以保留图像块的位置信息。位置编码通常是固定的(可以是可训练的),并且在patch嵌入上逐元素相加。

Transformer编码器(Transformer Encoder)

        包含多个堆叠的Transformer编码器层(例如12层)。每个编码器层由多头自注意力机制(Multi-Head Self-Attention)和前馈神经网络(Feed Forward Neural Network, FFN)组成。层归一化(Layer Normalization)和残差连接(Residual Connection)应用于每个子层。

分类令牌(Class Token)

        类似于BERT中的分类令牌([CLS]),DeiT 在patch嵌入中引入了一个可训练的分类令牌。分类令牌与图像块嵌入一起被输入到Transformer编码器中,通过自注意力机制交互获取全局信息。在经过所有编码器层后,分类令牌包含了整个图像的全局特征。

输出层

        分类令牌经过一个线性层后输出图像分类的预测结果。

蒸馏令牌(Distillation Token)(用于DeiT的蒸馏版本):

        除了分类令牌外,DeiT还引入了一个蒸馏令牌,用于从教师模型(如CNN)中学习知识。蒸馏令牌通过自注意力机制与其他令牌交互,最终通过一个线性层输出用于蒸馏损失计算的预测结果。

  • 8
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值