DeiT:Training data-efficient image transformers & distillation through attention
https://arxiv.org/pdf/2012.12877.pdf
GitHub - facebookresearch/deit: Official DeiT repository
2020 Dec
现有的基于Transformer的分类模型ViT需要在海量数据上(JFT-300M,3亿张图片)进行预训练,再在ImageNet数据集上进行fune-tuning,才能达到与CNN方法相当的性能,这需要非常大量的计算资源,这限制了ViT方法的进一步应用。
Data-efficient image transformers (DeiT) 无需海量预训练数据,只依靠ImageNet数据,便可以达到SOTA的结果,同时依赖的训练资源更少(4 GPUs in three days)。
DeiT采用的方法:
1)采用合适的训练策略包括optimizer, data augmentation, regularization等,这一块该文主要是在实验部分介绍;
2)采用蒸馏的方式,结合teacher model来引导基于Transformer的DeiT更好地学习
蒸馏
假设有一个性能很好的分类器(它可以是CNN,也可以是Transformer)作为teacher model。
在ViT的基础上,加上一个distillation token(右下角位置),就是维度+1,然后在self-attention layers中跟class token,patch token不断交互。它跟左下角的class token很像,唯一的区别在于,class token的目标是跟真实的label一致,而distillation token是要跟teacher model预测的label一致。
蒸馏分两种,一种是软蒸馏,另一种是硬蒸馏。
下式为软蒸馏:zs和zt分别是student model和teacher model的输出,φ表示sotfmax函数,λ和τ是超参数。
硬蒸馏:CE表示交叉熵。
软蒸馏是限制student和teacher的模型输出类别分布尽可能接近,而硬蒸馏是限制两种模型输出的类别标签尽可能接近。
训练的时候就用软蒸馏或硬蒸馏的loss进行训练。作者发现一个有趣的现象,class token和distillation token是朝着不同的方向收敛的,对各个layer的这两个token计算余弦相似度,平均值只有0.06,不过随着网络会越来越大,在最后一层是0.93,也就是相似但不相同。作者做了个实验来验证这个确实distillation token有给模型add something。就是简单地增加一个class token来代替distillation token,然后发现,即使对这两个class token进行独立的随机初始化,它们最终会收敛到同一个向量(余弦相似度为0.999),且性能没有明显提升。
三种版本的参数设置
DeiT-B是较大的模型,结构与ViT完全相同。DeiT-Ti和DeiT-S是两个较小的模型,区别在于heads数目和embedding dimension不同。
teacher的选择
CNN效果更好,这可能是因为transformer可以学到CNN的归纳假设。
CNN是有inductive bias的,例如局部感受野,参数共享等,这些设计比较适应于图像任务,这里将CNN作为teacher,可以通过蒸馏,使得Transformer学习得到CNN的inductive bias,从而提升Transformer对图像任务的处理能力。
m小标表示使用了蒸馏策略的网络模型。
↑384表示student在224*224图像上进行预训练,然后在384*384图像上进行fine-tune。
RegNetY-16GF:《Designing Network Design Spaces》CVPR2020
https://arxiv.org/abs/2003.13678
distillation效果比较
hard distillation性能比soft distillation更好。
在测试性能的时候,是用class token还是distillation token还是都用呢?在pretrain上测试的时候,两个一起用性能更佳。只用一个的时候,distillation token性能略好于class token这可能是因为distillation token里有更多从CNN中学到的归纳假设。还有一个很有趣的现象是,Transformer的性能比它teacher CNN还好。
模型性能对比
训练策略
初始化与超参数
Optimizer
AdamW比SGD性能更好。
数据增强
Rand-Augment:
采用与基本数据增广方式相同的一系列方法如:identity、rotate…具体的增广手段如下表所示共K=14种。但本文并非目的是为了让CNN网络学习如何对针对不同的数据集进行增广,而是随机选择数据增广的方式。即随机选择一种手段对原始数据进行变换,并调整它们的大小。
AutoAug:根据特定数据集搜索最优增强方法
Mixup:线性插值,:将随机的两张样本按比例混合,分类的结果按比例分配
CutMix:就是将一部分区域cut掉但不填充0像素而是随机填充训练集中的其他数据的区域像素值,分类结果按一定的比例分配
Transformer的训练需要大量的数据,想要在不太大的数据集上取得好性能,就需要大量的数据增强,以实现data-efficient training。如图9所示,几乎所有评测过的数据增强的方法都能提升性能。
重点采用的数据扩充方式是Rand-Augment, Mixup, CutMix。
Regularization
Repeated Aug:
《Multigrain: a unified image embedding for classes and instances》
《Augment your batch: Improving generalization through instance repeation》
Exp Moving Avg:EMA 滑动平均
采用Random Erasing和Stochastic depth等方式有助于模型的收敛,尤其是采用较深的模型时。
Random Erasing:随机选择一个区域,然后采用随机值进行覆盖。
Stochastic depth: 随机失活一些卷积层,只保留 shortcut 通路的方式随机跳过 一些 Residual Blocks