系列文章目录
论文阅读笔记 (1):DeepLabv3
论文阅读笔记 (2) :STA手势识别
论文阅读笔记 (3): ST-GCN
论文阅读笔记(5):图上的光谱网路和深度局部链接网络
论文阅读笔记(6): GNN-快速局部光谱滤波
论文阅读笔记(8): 图卷积半监督分类
论文阅读笔记 (9) CrossViT: Cross-Attention Transformer 阅读笔记
Abstract
Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. These high- performing vision transformers are pretrained with hundreds of millions of images using a large infrastructure, thereby limiting their adoption.
直接了当的提出要解决的问题,那就是之前的transformer 性能好,但是需要很大的数据集进行预训练,这样导致应用受限。本文就要是处理这个问题的。
In this work, we produce competitive convolution-free transformers by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop) on ImageNet with no external data.
表达本文结果很棒,从训练师时间,效果上说明。
More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
表达了本文的重点工作,也是我们要弄明白的。
1. Introduction
- VIT 介绍。
- 应用的方法 timm中的改进,repeated augmentation.也就是用数据增强的方法改进效果。
- 解决的另一个问题:如何蒸馏这些模型。
how to distill these models? We introduce a token-based strategy, specific to transformers and denoted by DeiT⚗, and show that it advantageously replaces the usual distillation. - 贡献:
- 无序额外大型数据进行训练即可达到sta.Our two new models DeiT-S and DeiT-Ti have fewer param- eters and can be seen as the counterpart of ResNet-50 and ResNet-18.
- 新的蒸馏过程
- 蒸馏对象为卷积和transformer 的比较。
- 迁移学习。
2. Related Work
- Image Classification.
- The Transformer architecture
- Knowledge Distillation
知识蒸馏可以迁移归纳偏差,参考论文:Transferring inductive biases through knowledge distillation
所以本文将卷积网络的归纳偏差通过蒸馏引入transformer 中。
3. Methods
3.1 Vision transformer: overview
- MSA
- Transformer block for images : 将图像转换为pathes,以及位置编码的incorporate, 因为图像块本身是invariant的。
- the class token. 也就是CLS token。
- Fixing the positional encoding across resolutions: 用小分辨率图像进行预训练,用高分辨率图像微调,同时保持patch size 不变,那么num of patches 会发生变化,所以positon encoding 需要变化。
3.2 Distillation through attention
解决的问题:
how to learn a transformer by exploiting this teacher
比较内容
- hard distillation versus soft distillation
- classical distillation versus the distillation token.
3.2.1 Soft distillation
KL 散度来做学生和老师模型的损失:
其中
λ
\lambda
λ 控制散度损失和交叉熵损失比率,
Z
s
Z_s
Zs 为学生模型输出的
l
o
g
i
t
s
logits
logits,
τ
\tau
τ 为温度,
ψ
\psi
ψ为
s
o
f
t
m
a
x
softmax
softmax函数。
KL散度是什么东西呢:
K
L
(
A
∣
∣
B
)
=
∑
i
A
(
i
)
l
o
g
A
(
i
)
B
(
i
)
KL(A||B) = \sum_{i}A(i)log{\frac{A(i)}{B(i)}}
KL(A∣∣B)=i∑A(i)logB(i)A(i)
可以度量两个随机变量的距离,KL 散度就是两个概率分布A,B差别的非对称性的度量。A表示真实分布,B表示理论分布或者模型。目标就是要让模型分布尽可能与真实分布相同,所以 B是老师,A是学生。
值越小,表示越接近,因此可以直接用来作为损失函数,KL散度也成为相对熵损失。
温度 τ \tau τ的作用:
s
o
f
t
m
a
x
softmax
softmax 对i 类别进行作用 :
q
i
=
e
x
p
(
z
i
/
τ
)
∑
j
(
z
j
/
τ
)
q_{i} = \frac{exp(z_i/\tau)}{\sum_j{(z_j/\tau)}}
qi=∑j(zj/τ)exp(zi/τ)
对
z
i
z_i
zi使用1-5 温度,来查看softmax 后变化的效果:
import matplotlib.pyplot as plt
x0 = [0.2104, 0.2325, 0.3468, 0.2104]
x1 = [0.2307, 0.2425, 0.2962, 0.2307]
x2 = [0.2372, 0.2453, 0.2803, 0.2372]
x3 = [0.2405, 0.2466, 0.2725, 0.2405]
x4 = [0.2424, 0.2473, 0.2679, 0.2424]
plt.plot(range(4), x0, marker='o', label = "t=1")
plt.plot(range(4), x1, marker='o', label = "t=2")
plt.plot(range(4), x2, marker='o', label = "t=3")
plt.plot(range(4), x3, marker='o', label = "t=4")
plt.plot(range(4), x4, marker='o', label = "t=5")
plt.legend()
plt.show()
从图中可以看出,温度越大,变化越平缓。
也就是将原来容易区分的类别的距离拉近了,从而提高了分类难度。
放到公式中:
就是降低老师模型的平缓度,同时降低学生的平欢度,两者都是为了提高学生网络的学习难度的。
3.2.2 Hard label distillation
上边的是soft label 进行蒸馏,也就是通过学习学生网络最后的
l
o
g
i
t
logit
logit分布来进行的。而另一种就是hard label distillation,也就是学习不再是老师网络的
l
o
g
i
t
logit
logit,而是对
l
o
g
i
t
logit
logit进行
s
o
f
t
m
a
x
softmax
softmax,并且进行
a
r
g
m
a
x
(
.
)
argmax(.)
argmax(.)的结果,也就是老师预测的类别
y
′
y'
y′,这样算是函数的后一项就不再是KL 散度(相对熵),而变成了与第一项相同的交叉熵:
其中的
y
t
y_t
yt就是教师网络预测类别。
The teacher prediction
y
t
y_t
yt plays the same role as the true label
y
y
y.
虽然如此,但是本文中还用了 label smoothing的方法,可以看做soft label的模拟,true label:
1
−
ϵ
1-\epsilon
1−ϵ ,其他标签分享权重
ϵ
\epsilon
ϵ,本文
ϵ
=
0.1
\epsilon=0.1
ϵ=0.1.
3.2.3 Distillation token
从图上来看是很明了的,就是添加一个类似cls 的token,然后这个token 的输出是用来计算hard label的损失的,CLS token 依然是计算与ground truth 的损失的。
文中详细说明了这两种方式的不同。
3.2.4 Fine-tuning with distillation
微调过程依然是ground truth 和 hard label一起来进行微调,教师网络也是先用低分辨率来进行训练,用高分辨率再次微调得到,然后用来训练学生网络。
3.2.5 Classification with our approach: joint classifier
处理CLS ,和Distillation Token 预测融合的问题。
4. Experiments
4.1 Transformer Models
VIT 中的参数:
本文的参数:DIT-B 与VIT-B 对应,D=768, H=12, d = 64.
另外两种较小的 Dit-S, Dit-T, 都是保持d不变,改变head 数量。
4.2 Distillation
后边实验使用的老师都是RegNetY-16GF,直接训练,先在小分辨率预训练,然后在大分辨率微调,后边为384分辨率微调。
比较soft label 与 hard label,可以看出 用soft label 竟然没有改善。相反hard 却改善很多。
与卷积更加相关。(但问题是本来就是从卷积作为老师得到的模型啊?????)
4.3 Efficiency vs Accuracy
速度和准确度的判别,这里速度用了images/s 这种方式来表示。
4.4 Transfer learning
4.5 消融实验
5. Conclusions
论文思路:用数据增强是模型在ImageNet 上得到较好的结果(中间夹带蒸馏过程),然后在fintune 一波再次提升。其实模型结构上就比价VIT 多了个蒸馏token ??? 从消融实验部分看,实验量是很多。。。。。