[Transformer] DeiT:Training data-efficient image transformers & distillation through attention

DeiT:Training data-efficient image transformers & distillation through attention

https://arxiv.org/pdf/2012.12877.pdf

GitHub - facebookresearch/deit: Official DeiT repository

2020 Dec

现有的基于Transformer的分类模型ViT需要在海量数据上(JFT-300M3亿张图片)进行预训练,再在ImageNet数据集上进行fune-tuning,才能达到与CNN方法相当的性能,这需要非常大量的计算资源,这限制了ViT方法的进一步应用。

Data-efficient image transformers (DeiT) 无需海量预训练数据,只依靠ImageNet数据,便可以达到SOTA的结果,同时依赖的训练资源更少(4 GPUs in three days)。

 

DeiT采用的方法:

1)采用合适的训练策略包括optimizer, data augmentation, regularization等,这一块该文主要是在实验部分介绍;

2)采用蒸馏的方式,结合teacher model来引导基于Transformer的DeiT更好地学习

蒸馏

假设有一个性能很好的分类器(它可以是CNN,也可以是Transformer)作为teacher model。

在ViT的基础上,加上一个distillation token(右下角位置),就是维度+1,然后在self-attention layers中跟class token,patch token不断交互。它跟左下角的class token很像,唯一的区别在于,class token的目标是跟真实的label一致,而distillation token是要跟teacher model预测的label一致。

 

 

蒸馏分两种,一种是软蒸馏,另一种是硬蒸馏。

下式为软蒸馏:zszt分别是student modelteacher model的输出,φ表示sotfmax函数,λ和τ是超参数。

硬蒸馏:CE表示交叉熵。

 

软蒸馏是限制studentteacher的模型输出类别分布尽可能接近,而硬蒸馏是限制两种模型输出的类别标签尽可能接近。

训练的时候就用软蒸馏或硬蒸馏的loss进行训练。作者发现一个有趣的现象,class tokendistillation token是朝着不同的方向收敛的,对各个layer的这两个token计算余弦相似度,平均值只有0.06,不过随着网络会越来越大,在最后一层是0.93,也就是相似但不相同。作者做了个实验来验证这个确实distillation token有给模型add something。就是简单地增加一个class token来代替distillation token,然后发现,即使对这两个class token进行独立的随机初始化,它们最终会收敛到同一个向量(余弦相似度为0.999),且性能没有明显提升。

 

三种版本的参数设置

DeiT-B是较大的模型,结构与ViT完全相同。DeiT-TiDeiT-S是两个较小的模型,区别在于heads数目和embedding dimension不同。

 

teacher的选择

CNN效果更好,这可能是因为transformer可以学到CNN的归纳假设。

CNN是有inductive bias的,例如局部感受野,参数共享等,这些设计比较适应于图像任务,这里将CNN作为teacher,可以通过蒸馏,使得Transformer学习得到CNNinductive bias,从而提升Transformer对图像任务的处理能力。

m小标表示使用了蒸馏策略的网络模型。

384表示student224*224图像上进行预训练,然后在384*384图像上进行fine-tune

RegNetY-16GF:《Designing Network Design SpacesCVPR2020

https://arxiv.org/abs/2003.13678

 

 

distillation效果比较

hard distillation性能比soft distillation更好。

在测试性能的时候,是用class token还是distillation token还是都用呢?在pretrain上测试的时候,两个一起用性能更佳。只用一个的时候,distillation token性能略好于class token这可能是因为distillation token里有更多从CNN中学到的归纳假设。还有一个很有趣的现象是,Transformer的性能比它teacher CNN还好。

 

模型性能对比

 

训练策略

初始化与超参数

Optimizer

AdamW比SGD性能更好。

数据增强

Rand-Augment:

采用与基本数据增广方式相同的一系列方法如:identity、rotate…具体的增广手段如下表所示共K=14种。但本文并非目的是为了让CNN网络学习如何对针对不同的数据集进行增广,而是随机选择数据增广的方式。即随机选择一种手段对原始数据进行变换,并调整它们的大小。

AutoAug:根据特定数据集搜索最优增强方法

Mixup:线性插值,:将随机的两张样本按比例混合,分类的结果按比例分配

CutMix:就是将一部分区域cut掉但不填充0像素而是随机填充训练集中的其他数据的区域像素值,分类结果按一定的比例分配

Transformer的训练需要大量的数据,想要在不太大的数据集上取得好性能,就需要大量的数据增强,以实现data-efficient training。如图9所示,几乎所有评测过的数据增强的方法都能提升性能。

重点采用的数据扩充方式是Rand-Augment, Mixup, CutMix。

Regularization

Repeated Aug:

《Multigrain: a unified image embedding for classes and instances》

《Augment your batch: Improving generalization through instance repeation》

Exp Moving Avg:EMA 滑动平均

采用Random Erasing和Stochastic depth等方式有助于模型的收敛,尤其是采用较深的模型时。

Random Erasing:随机选择一个区域,然后采用随机值进行覆盖。

Stochastic depth: 随机失活一些卷积层,只保留 shortcut 通路的方式随机跳过 一些 Residual Blocks

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值