【论文阅读】【ViT系列】DeiT:数据高效的图像transformers的训练&通过注意力的蒸馏

论文:Training data-efficient image transformers & distillation through attention
代码:https://github.com/facebookresearch/deit

1 主要贡献

  • 使用ImageNet数据(无外部数据)将无卷积层的transformer网络训练到了SOTA水平,训练时间短;
    (原始的vision transformer需要使用非公开的JFT-300M数据集进行大量训练,无法复现)
  • 提出了基于distillation token的蒸馏机制,distillation token用于学习教师网络的预测结果;
  • 图像transformers从卷积网络中学习的效果优于从其他transformers中学习;
  • 在ImageNet上预学习的网络在多个下游基准中也有竞争力。

性能对比:top-1准确率 vs. 网络吞吐量(仅在ImageNet1k上训练)——使用transformer专用蒸馏方法训练的模型最优。

2 原理

2.1 Vision Transformer

回顾原始ViT的原理:

  • 多头自注意力层(MSA)的设计:transformer;
  • 针对图片的tranformer block:FFN + MSA;
  • class token:来自NLP;
  • 训练时用低分辨率图片,微调时用高分辨率图片,改变分辨率时插值改变位置编码。

2.2 Distillation through attention

主要介绍了软蒸馏、硬蒸馏两种损失函数,和Distillation token结构。

2.2.1 软蒸馏

L g l o b a l = ( 1 − λ ) L C E ( ψ ( Z s ) , y ) + λ τ 2 K L ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) \mathcal{L}_{global}=(1-\lambda)\mathcal{L}_{CE}(\psi(Z_s),y)+\lambda\tau^2\mathrm{KL}(\psi(Z_s/\tau),\psi(Z_t/\tau)) Lglobal=(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))
λ , τ \lambda,\tau λ,τ是超参数, y y y是ground truth, ψ \psi ψ是softmax函数, Z s , Z t Z_s,Z_t Zs,Zt分别是学生模型、教师模型的输出, L C E \mathcal{L}_{CE} LCE是交叉熵损失, K L \mathrm{KL} KL是KL散度。

2.2.2 硬蒸馏

L g l o b a l h a r d D i s t i l l = 1 2 L C E ( ψ ( Z s ) , y ) + 1 2 τ 2 L C E ( ψ ( Z s ) , y t ) \mathcal{L}_{global}^{hardDistill}=\frac12\mathcal{L}_{CE}(\psi(Z_s),y)+\frac12\tau^2\mathcal{L}_{CE}(\psi(Z_s),y_t) LglobalhardDistill=21LCE(ψ(Zs),y)+21τ2LCE(ψ(Zs),yt)
y t y_t yt是教师模型的预测结果。

2.2.3 Distillation token

如图,在patches中加入与class token类似的distillation token,两者的通过网络时的计算方式相同,区别在于class token目标是重现ground truth标签,而distillation token目标是重现教师模型的预测。
输出时的distillation token与class token余弦相似度为0.93,表明两者的目标相似但不相同。
当用一个class token替换distillation token时,两个class token输出的余弦相似度为0.999,网络性能与一个class token相近,而加入distillation token的网络性能明显提升。这表明distillation token的设定是有效的。

2.2.4 微调和分类

对蒸馏的结构进行微调,需要将教师网络的目标分辨率提升。
分类结果由class和distillation输出的softmax之和决定。

3 实验

3.1 Transformer模型

定义了与ViT-B参数相同的DeiT-B模型,和更小的DeiT-S、DeiT-Ti模型,超参数如下:

3.2 蒸馏

3.2.1 教师模型对比

实验发现RegNetY-16GF是效果最好的教师模型,后续实验默认选择。卷积网络教师优于transformer教师,可能因为继承了卷积网络的bias。

3.2.2 蒸馏方法对比

硬蒸馏优于软蒸馏,class和distillation token同时使用优于单独使用一个。

3.2.3 教师网络引入的归纳偏差

下表为不同分类器中分类结果不同的比例。结果表明,使用distillation embedding的分类器结果与卷积网络更相似,使用class embedding的分类器结果与无蒸馏的DeiT更相似,两者结合的分类器结果介于两者之间。

3.2.4 epochs数量

300epochs后使用distillation token的网络已经占优,且性能仍未饱和,继续训练可以提升准确率。

3.3 效率vs准确率:与卷积网络对比

使用timm库的实现,对比DeiT、ViT和卷积网络的准确率和效率(吞吐量)。

3.4 迁移学习:在下游任务的表现

在其他数据集的预测表现:

3.5 训练细节和消融

消融实验:adamw优化器优于SGD,各种数据增强方法几乎都有效(除了dropout)。

DeiT对参数初始化相对敏感,使用截断的正态分布进行初始化。DeiT对优化超参数很敏感。
ViT-B和DeiT-B的训练超参数如下表:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值