『论文精读』Data-efficient image Transformers(DeiT)论文解读

本文解读了DeiT论文,一种无需大量预训练数据的高效图像Transformer模型,通过知识蒸馏、更好的超参数设置和数据增强等方法,减少计算资源需求,同时在ImageNet上达到最先进的性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

『论文精读』Data-efficient image Transformers(DeiT)论文解读

一. DeiT简介

  • 现有的基于Transformer的分类模型ViT需要 在海量数据上(JFT-300M,3亿张图片) 进行预训练,再在ImageNet数据集上进行fune-tuning,才能达到与CNN方法相当的性能,这需要非常大量的计算资源,这限制了ViT方法的进一步应用。

在这里插入图片描述

  • DeiT的模型和VIT的模型几乎是相同的,可以理解为本质上是在训一个VIT。
  • better hyperparameter:指的是模型初始化、learning-rate等设置。
  • data augmentation:在只有120万张图片的Imagenet,使用数据增广模拟更多数据。
  • Distillation:知识蒸馏。
  • 三部分的作用分别为:保证模型更好的收敛、可以使用小的数据训练、进一步提升性能。还有一些其他的方式,如:warmup、label smoothing、droppath等。

在这里插入图片描述

  • Data-efficient image transformers (DeiT) 无需海量预训练数据,只依靠ImageNet数据,便可以达到SOTA的结果,同时依赖的训练资源更少(4 GPUs in three days)。

在这里插入图片描述

  • 文章贡献如下:
  • 仅使用Transformer,不引入Conv的情况下也能达到SOTA效果。
  • 提出了基于token蒸馏的策略,这种针对transformer的蒸馏方法可以超越原始的蒸馏方法。
  • Deit发现使用Convnet作为教师网络能够比使用Transformer架构取得更好的效果。

在这里插入图片描述

二. 知识蒸馏(knowledge distillation)

  • Knowledge Distillation(KD)最初被Hinton提出,与Label smoothing动机类似,但是KD生成soft label的方式是通过教师网络得到的。
  • KD可以视为将教师网络学到的信息压缩到学生网络中。还有一些工作“Circumventing outlier of autoaugment with knowledge distillation”则将KD视为数据增强方法的一种。
  • KD能够以soft的方式将归纳偏置传递给学生模型,Deit中使用Conv-Based架构作为教师网络,将局部性的假设通过蒸馏方式引入Transformer中,取得了不错的效果。
  • 简单来说就是用teacher模型去训练student模型,通常teacher模型更大而且已经训练好了,student模型是我们当前需要训练的模型。在这个过程中,teacher模型是不训练的。
  • 当teacher模型和student模型拿到相同的图片时,都进行各自的前向,这时teacher模型就拿到了具有分类信息的feature,在进行softmax之前先除以一个参数 τ \tau τ,叫做temperature(蒸馏温度),然后softmax得到soft labels(区别于one-hot形式的hard-label)。
  • student模型也是除以同一个 τ \tau τ,然后softmax得到一个soft-prediction,我们希望student模型的soft-prediction和teacher模型的soft labels尽量接近,使用KLDivLoss进行两者之间的差距度量,计算一个对应的损失teacher loss
  • 在训练的时候,我们是可以拿的到训练图片的真实的ground truth(hard label)的,可以看到上面图中student模型下面一路,就是预测结果和真是标签之间计算交叉熵crossentropy。
  • 链接:损失函数|交叉熵损失函数
  • 然后两路计算的损失:KLDivLoss和CELoss,按照一个加权关系计算得到一个总损失total loss,反向修改参数的时候这个teacher模型是不做训练的,只依据total loss训练student模型。
  • 还可以使用硬蒸馏,对比上面的结构图,哪种更好没有定论。

2.1. KLDivloss

  • KL divergence(KL散度又叫相对熵): 它表示用分布 q ( x ) q(x) q(x) 模拟真实分布 p ( x ) p(x) p(x) 所需要的额外信息。同时也叫KL距离,就是是两个随机分布间距离的度量。
  • 取值范围: [ 0 , + ∞ ] [0, +\infty ] [0,+]当两个分布接近相同的时候KL散度取值为0,当两个分布差异越来越大的时候KL散度值就会越来越大。
    D K L ( p ∣ q ) = H ( p , q ) ⏟ 交叉熵 − H ( p ) ⏟ 信息熵 = − ∑ i = 1 n p ( x i ) log ⁡ q ( x i ) + ∑ i = 1 n p ( x i ) log ⁡ p ( x i ) = ∑ i = 1 n p ( x i ) log ⁡ p ( x i ) q ( x i ) (1) \begin{aligned} {D}_{K L}({p} | {q})&=\underbrace{H(p, q)}_{\text {交叉熵}}-\underbrace{H(p)}_{\text {信息熵}}\\&=-\sum_{i=1}^{n}{p}(x_i) \log {q}(x_i)+\sum_{i=1}^{n} {p}(x_i) \log {p}(x_i) \\ &=\sum_{i=1}^{n} {p}(x_i) \log \frac{{p}(x_i)}{{q}(x_i)}\tag{1} \end{aligned} DKL(pq)=交叉熵 H(p,q)信息熵 H(p)=i=1np(xi)logq(xi)+i=1np(xi)logp(xi)=i=1np(xi)logq(xi)p(xi)(1) 注意: 直观来说,由于 p ( x ) p(x) p(x) 是已知的分布(真实分布), H ( p ) H(p) H(p) 是个常数,交叉熵和KL散度之间相差一个这样的常数(信息熵)
  • 当两个分布完全一致时候,KL散度就等于0。KLDivloss定义和使用方式为:

2.2. 蒸馏温度 τ \tau τ

  • 蒸馏温度 τ \tau τ 的作用,回想之前VIT中在self-attention里面计算 q , k \mathbf {q,k} q,k间的加权因子的时候,计算完了要scale(除以 k k k 的维度),然后再做softmax,然后用它们对 v \mathbf v v 加权相加得到对应的表示向量。
  • 如果是[1.0,20.0,400.0]直接做softamx,那结果是[0.0,0.0,1.0],可见结果完全借鉴第三个引子。而先进行处理(比如除以1000)后变为[0.001,0.02,0.4]时,在做softamx结果为[0.28,0.29,0.42]结果总综合考虑了三部分,这显然是更合理的结果。实际中,看我是更希望结果偏向于更大的值,还是偏向于综合考虑来决定是否使用softmax前输入的预处理。

2.3. distillation in transformer

这一节主要弄清楚,如何在transformer中进行蒸馏操作。

在这里插入图片描述

  • 先说一下,在这DeiT篇论文出来的时候,teacher model使用的是Regnet(一个CNN)
  • 在VIT中时使用class tokens去做分类的,相当于是一个额外的patch,这个patch去学习和别的patch之间的关系,然后连classifier,计算CELoss。在DeiT中为了做蒸馏,又额外加一个distill token,这个distill token也是去学和其他tokens之间的关系,然后连接teacher model计算KLDivLoss,那CELoss和KLDivLoss共同加权组合成一个新的loss取指导student model训练(知识蒸馏中teacher model不训练)。
  • 在预测阶段,class token和distill token分别产生一个结果,然后将其加权(分别0.5),再加在一起,得到最终的结果做预测。

L global  = ( 1 − λ ) L C E ( ψ ( Z s ) , y ) + λ τ 2 K L ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) (2) \mathcal{L}_{\text {global }}=(1-\lambda) \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_{\mathrm{s}}\right), y\right)+\lambda \tau^2 \mathrm{KL}\left(\psi\left(Z_{\mathrm{s}} / \tau\right), \psi\left(Z_{\mathrm{t}} / \tau\right)\right)\tag{2} Lglobal =(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))(2)

L global  hardill  = 1 2 L C E ( ψ ( Z s ) , y ) + 1 2 L C E ( ψ ( Z s ) , y t ) (3) \mathcal{L}_{\text {global }}^{\text {hardill }}=\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_s\right), y\right)+\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_s\right), y_{\mathrm{t}}\right)\tag{3} Lglobal hardill =21LCE(ψ(Zs),y)+21LCE(ψ(Zs),yt)(3)

三. better hyperparameter

  • DeiT中第二个优化点在于better hyperparameter,也就是更好的参数配置,看看其都包含哪些部分。

在这里插入图片描述

  • 参数初始化方式:truncated normal distribution(截断标准分布)
  • learning-rate:CNN中的结论:当batch size越大的时候,learning rate设置的越大。
  • learning rate decay:cosine,在warm-up阶段lr先线性升上去,然后通过余弦方式lr降下来

四. data augmentation

在这里插入图片描述

  • mixup之后的图片的label不再是单一的label,而是soft-label,比如[cat,dog]=[0.5,0.5]
  • cutmix之后的图片label是按所占据的比例给的,比如[cat,dog]=[0.3,0.7]

在这里插入图片描述

  • randomaug其实是由autoaug来的,autoaug是选取了25中增强策略,每种策略中有两个操作,这两种操作都要被执行。每次为一张图随机从25中策略中选取一种,将这两种操作对该图执行。至于这25中策略是怎么组成的,每种里面的操作的概率是如何确立的,这些是由搜索算法的实现的,总之认为这么搭配有效就行了。对于randomaug,相当于对于autoaug的简化,它是13种增强策略,然后从中一次选取6种策略依次对图片进行操作,完成增强操作。
  • model EMA(Exponential Moving Average)指数滑动平均,使得模型权重更新与一段时间内的历史取值有关。 m t m_{t} mt 是当前的模型权重, m t − 1 m_{t-1} mt1 是上一轮模型权重, θ t \theta_{t} θt为模型当前权重的值,举一个例子:

在这里插入图片描述
在这里插入图片描述

  • 三种更新参数方式的更新参数结果曲线:

在这里插入图片描述

  • 实际使用的时候,设置上面例子中的 β \beta β 值例如为0.99996,保证模型的参数值不会乱动。

五. label smoothing

  • label smoothing:原本hard-label变成soft-label,设置参数,给其余非标签平均一些label概率。
     Label  one hot  = [ 1 , 0 , 0 , 0 , 0 , 0 ]  Label  smoothing  = [ 0.9 , 0.02 , 0.02 , 0.02 , 0.02 , 0.02 ] , α = 0.1 \begin{aligned} & \text { Label }_{\text {one hot }}=[1,0,0,0,0,0] \\ & \text { Label }_{\text {smoothing }}=[0.9,0.02,0.02,0.02,0.02,0.02], \alpha=0.1 \end{aligned}  Label one hot =[1,0,0,0,0,0] Label smoothing =[0.9,0.02,0.02,0.02,0.02,0.02],α=0.1

在这里插入图片描述
在这里插入图片描述

参考文献

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI大模型前沿研究

感谢您的打赏,我会继续努力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值