文章目录
Training data-efficient image transformers & distillation through attention
基本信息
- 论文链接:arxiv
- 发表时间:2020
- 应用场景:图像分类
摘要
存在什么问题 | 解决了什么问题 |
---|---|
1. vit效果虽好,但是所需要的pretrain数据量太大了,所以也就限制了其应用场景。 | 1. 提出了一个pure transformer模型deit,不需要大规模的pretrain数据集,只需在ImageNet上做pretrain就能达到83.1% top1准确率,另外单机8卡训练只需要3天。 |
模型结构
该方法需要借助模型蒸馏实现,所以先要有一个训练好的分类器当做teacher,这个分类器可以使ConvNet,也可以是VIT。
两种蒸馏策略
-
Soft distillation(使用较为广泛)
作用:减小teacher和student输出logits后的softmax向量间的KL散度。
其中ψ代表softmax,τ代表蒸馏时的温度系数。
-
Hard-label distillation(Deit所采用的方法)
其中 y t = a r g m a x c Z t ( c ) y_t=argmax_c^{Z_t(c)} yt=argmaxcZt(c)代表了teacher输出的softmax。即student不仅要和gt做CELoss,还要和teacher输出的softmax做CELoss。
该方法不需要调参,也就没有了温度系数τ。
还有一个改动就是对y做了label smoothing(ε=0.1),那么此时不论是y还是 y t y_t yt就都是soft label了。
Distillation token
和CLS token一样,将distillation token加在网络的最前面,所以此时token总数为:patch_num+2,一同和其他token进行前向传播。最后输出的时候会在distillation token加个FC与teacher输出的soft label做CELoss,从而完成整个训练流程。
作者发现随着网络的深入,cls token和distillation token之间的余弦相似度逐渐增大,但是最后也只有0.93,并没有达到1,也就是说虽然所学习的内容存在类似,但还是有一定的区别。
另外为了验证distillation token确实让网络学到了新的内容,作者将distillation token换成另一个cls token,然后让网络中的两个cls token去学习同一个gt,不论如何初始化,在最后一层两个cls token之间的余弦相似度总会达到0.99,即新加的cls token没有给网络带来任何收益。验证了distillation token有效性。
其他细节
- 训练的时候cls token和distillation token分别通过各自的FC和gt、teacher output的softmax做loss,但是在预测时候会采取两者输出求平均,再取argmax获得预测结果的策略。
实验
用ConvNet作为teacher会比用VIT作为teacher其student模型精度更高。
student模型精度会超过teacher模型的进度,得益于student模型通过distillation token继承了teacher模型的归纳偏置。
验证了Deit预训练的时候同时用label以及teacher输出蒸馏的时候效果是最好的,另外只用teacher蒸馏要比只用gt的模型精度要高。
验证了hard labe distillation是最优的。
Deit和之前SOTA模型间的比较
总结
-
Deit在VIT的基础上,加了两个trick——hard label distillation和distillation token使得VIT模型不再需要像JFT-300M超大规模用于预训练的数据集,仅在ImageNet1k上做pretrain就可以得到一个与VIT效果可比的VIT模型,另外训练时间大幅缩短,是一个简单、高效的trick。
-
Deit和VIT相比用了非常多数据增强方法,并且随着数据增强手段越来越丰富,精度也越来越高。所以VIT相比于ConvNet还是缺少了归纳偏置,需要大数据和数据增强来让模型尝试去学习这些归纳偏置。