论文:Training data-efficient image transformers & distillation through attention
代码:https://github.com/facebookresearch/deit
目录
1 主要贡献
- 使用ImageNet数据(无外部数据)将无卷积层的transformer网络训练到了SOTA水平,训练时间短;
(原始的vision transformer需要使用非公开的JFT-300M数据集进行大量训练,无法复现) - 提出了基于distillation token的蒸馏机制,distillation token用于学习教师网络的预测结果;
- 图像transformers从卷积网络中学习的效果优于从其他transformers中学习;
- 在ImageNet上预学习的网络在多个下游基准中也有竞争力。
性能对比:top-1准确率 vs. 网络吞吐量(仅在ImageNet1k上训练)——使用transformer专用蒸馏方法训练的模型最优。
2 原理
2.1 Vision Transformer
回顾原始ViT的原理:
- 多头自注意力层(MSA)的设计:transformer;
- 针对图片的tranformer block:FFN + MSA;
- class token:来自NLP;
- 训练时用低分辨率图片,微调时用高分辨率图片,改变分辨率时插值改变位置编码。
2.2 Distillation through attention
主要介绍了软蒸馏、硬蒸馏两种损失函数,和Distillation token结构。
2.2.1 软蒸馏
L
g
l
o
b
a
l
=
(
1
−
λ
)
L
C
E
(
ψ
(
Z
s
)
,
y
)
+
λ
τ
2
K
L
(
ψ
(
Z
s
/
τ
)
,
ψ
(
Z
t
/
τ
)
)
\mathcal{L}_{global}=(1-\lambda)\mathcal{L}_{CE}(\psi(Z_s),y)+\lambda\tau^2\mathrm{KL}(\psi(Z_s/\tau),\psi(Z_t/\tau))
Lglobal=(1−λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))
λ
,
τ
\lambda,\tau
λ,τ是超参数,
y
y
y是ground truth,
ψ
\psi
ψ是softmax函数,
Z
s
,
Z
t
Z_s,Z_t
Zs,Zt分别是学生模型、教师模型的输出,
L
C
E
\mathcal{L}_{CE}
LCE是交叉熵损失,
K
L
\mathrm{KL}
KL是KL散度。
2.2.2 硬蒸馏
L
g
l
o
b
a
l
h
a
r
d
D
i
s
t
i
l
l
=
1
2
L
C
E
(
ψ
(
Z
s
)
,
y
)
+
1
2
τ
2
L
C
E
(
ψ
(
Z
s
)
,
y
t
)
\mathcal{L}_{global}^{hardDistill}=\frac12\mathcal{L}_{CE}(\psi(Z_s),y)+\frac12\tau^2\mathcal{L}_{CE}(\psi(Z_s),y_t)
LglobalhardDistill=21LCE(ψ(Zs),y)+21τ2LCE(ψ(Zs),yt)
y
t
y_t
yt是教师模型的预测结果。
2.2.3 Distillation token
如图,在patches中加入与class token类似的distillation token,两者的通过网络时的计算方式相同,区别在于class token目标是重现ground truth标签,而distillation token目标是重现教师模型的预测。
输出时的distillation token与class token余弦相似度为0.93,表明两者的目标相似但不相同。
当用一个class token替换distillation token时,两个class token输出的余弦相似度为0.999,网络性能与一个class token相近,而加入distillation token的网络性能明显提升。这表明distillation token的设定是有效的。
2.2.4 微调和分类
对蒸馏的结构进行微调,需要将教师网络的目标分辨率提升。
分类结果由class和distillation输出的softmax之和决定。
3 实验
3.1 Transformer模型
定义了与ViT-B参数相同的DeiT-B模型,和更小的DeiT-S、DeiT-Ti模型,超参数如下:
3.2 蒸馏
3.2.1 教师模型对比
实验发现RegNetY-16GF是效果最好的教师模型,后续实验默认选择。卷积网络教师优于transformer教师,可能因为继承了卷积网络的bias。
3.2.2 蒸馏方法对比
硬蒸馏优于软蒸馏,class和distillation token同时使用优于单独使用一个。
3.2.3 教师网络引入的归纳偏差
下表为不同分类器中分类结果不同的比例。结果表明,使用distillation embedding的分类器结果与卷积网络更相似,使用class embedding的分类器结果与无蒸馏的DeiT更相似,两者结合的分类器结果介于两者之间。
3.2.4 epochs数量
300epochs后使用distillation token的网络已经占优,且性能仍未饱和,继续训练可以提升准确率。
3.3 效率vs准确率:与卷积网络对比
使用timm库的实现,对比DeiT、ViT和卷积网络的准确率和效率(吞吐量)。
3.4 迁移学习:在下游任务的表现
在其他数据集的预测表现:
3.5 训练细节和消融
消融实验:adamw优化器优于SGD,各种数据增强方法几乎都有效(除了dropout)。
DeiT对参数初始化相对敏感,使用截断的正态分布进行初始化。DeiT对优化超参数很敏感。
ViT-B和DeiT-B的训练超参数如下表: