【论文笔记9】DeiT: Training data efficient image transformers & distillation through attention.阅读笔记

系列文章目录

论文阅读笔记 (1):DeepLabv3
论文阅读笔记 (2) :STA手势识别
论文阅读笔记 (3): ST-GCN
论文阅读笔记(5):图上的光谱网路和深度局部链接网络
论文阅读笔记(6): GNN-快速局部光谱滤波
论文阅读笔记(8): 图卷积半监督分类
论文阅读笔记 (9) CrossViT: Cross-Attention Transformer 阅读笔记


Abstract

Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. These high- performing vision transformers are pretrained with hundreds of millions of images using a large infrastructure, thereby limiting their adoption.
直接了当的提出要解决的问题,那就是之前的transformer 性能好,但是需要很大的数据集进行预训练,这样导致应用受限。本文就要是处理这个问题的。
In this work, we produce competitive convolution-free transformers by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop) on ImageNet with no external data.
表达本文结果很棒,从训练师时间,效果上说明。
More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
表达了本文的重点工作,也是我们要弄明白的。

1. Introduction

  • VIT 介绍。
  • 应用的方法 timm中的改进,repeated augmentation.也就是用数据增强的方法改进效果。
  • 解决的另一个问题:如何蒸馏这些模型。
    how to distill these models? We introduce a token-based strategy, specific to transformers and denoted by DeiT⚗, and show that it advantageously replaces the usual distillation.
  • 贡献:
    • 无序额外大型数据进行训练即可达到sta.Our two new models DeiT-S and DeiT-Ti have fewer param- eters and can be seen as the counterpart of ResNet-50 and ResNet-18.
    • 新的蒸馏过程
    • 蒸馏对象为卷积和transformer 的比较。
    • 迁移学习。

2. Related Work

3. Methods

3.1 Vision transformer: overview

  • MSA
  • Transformer block for images : 将图像转换为pathes,以及位置编码的incorporate, 因为图像块本身是invariant的。
  • the class token. 也就是CLS token。
  • Fixing the positional encoding across resolutions: 用小分辨率图像进行预训练,用高分辨率图像微调,同时保持patch size 不变,那么num of patches 会发生变化,所以positon encoding 需要变化。

3.2 Distillation through attention

解决的问题
how to learn a transformer by exploiting this teacher

比较内容

  • hard distillation versus soft distillation
  • classical distillation versus the distillation token.

3.2.1 Soft distillation

KL 散度来做学生和老师模型的损失:
在这里插入图片描述
其中 λ \lambda λ 控制散度损失和交叉熵损失比率, Z s Z_s Zs 为学生模型输出的 l o g i t s logits logits τ \tau τ 为温度, ψ \psi ψ s o f t m a x softmax softmax函数。

KL散度是什么东西呢:

K L ( A ∣ ∣ B ) = ∑ i A ( i ) l o g A ( i ) B ( i ) KL(A||B) = \sum_{i}A(i)log{\frac{A(i)}{B(i)}} KL(AB)=iA(i)logB(i)A(i)
可以度量两个随机变量的距离,KL 散度就是两个概率分布A,B差别的非对称性的度量。A表示真实分布,B表示理论分布或者模型。目标就是要让模型分布尽可能与真实分布相同,所以 B是老师,A是学生。
值越小,表示越接近,因此可以直接用来作为损失函数,KL散度也成为相对熵损失。

温度 τ \tau τ的作用:

s o f t m a x softmax softmax 对i 类别进行作用 :
q i = e x p ( z i / τ ) ∑ j ( z j / τ ) q_{i} = \frac{exp(z_i/\tau)}{\sum_j{(z_j/\tau)}} qi=j(zj/τ)exp(zi/τ)
在这里插入图片描述
z i z_i zi使用1-5 温度,来查看softmax 后变化的效果:

import matplotlib.pyplot as plt
x0 = [0.2104, 0.2325, 0.3468, 0.2104]
x1 = [0.2307, 0.2425, 0.2962, 0.2307]
x2 = [0.2372, 0.2453, 0.2803, 0.2372]
x3 = [0.2405, 0.2466, 0.2725, 0.2405]
x4 = [0.2424, 0.2473, 0.2679, 0.2424]

plt.plot(range(4), x0, marker='o', label = "t=1")
plt.plot(range(4), x1, marker='o', label = "t=2")
plt.plot(range(4), x2, marker='o', label = "t=3")
plt.plot(range(4), x3, marker='o', label = "t=4")
plt.plot(range(4), x4, marker='o', label = "t=5")
plt.legend()
plt.show()

在这里插入图片描述
从图中可以看出,温度越大,变化越平缓
也就是将原来容易区分的类别的距离拉近了,从而提高了分类难度。
放到公式中:
在这里插入图片描述
就是降低老师模型的平缓度,同时降低学生的平欢度,两者都是为了提高学生网络的学习难度的。

3.2.2 Hard label distillation

上边的是soft label 进行蒸馏,也就是通过学习学生网络最后的 l o g i t logit logit分布来进行的。而另一种就是hard label distillation,也就是学习不再是老师网络的 l o g i t logit logit,而是对 l o g i t logit logit进行 s o f t m a x softmax softmax,并且进行 a r g m a x ( . ) argmax(.) argmax(.)的结果,也就是老师预测的类别 y ′ y' y,这样算是函数的后一项就不再是KL 散度(相对熵),而变成了与第一项相同的交叉熵:
在这里插入图片描述
其中的 y t y_t yt就是教师网络预测类别。
The teacher prediction y t y_t yt plays the same role as the true label y y y.
虽然如此,但是本文中还用了 label smoothing的方法,可以看做soft label的模拟,true label: 1 − ϵ 1-\epsilon 1ϵ ,其他标签分享权重 ϵ \epsilon ϵ,本文 ϵ = 0.1 \epsilon=0.1 ϵ=0.1.

3.2.3 Distillation token

在这里插入图片描述
从图上来看是很明了的,就是添加一个类似cls 的token,然后这个token 的输出是用来计算hard label的损失的,CLS token 依然是计算与ground truth 的损失的。
文中详细说明了这两种方式的不同。

3.2.4 Fine-tuning with distillation

微调过程依然是ground truth 和 hard label一起来进行微调,教师网络也是先用低分辨率来进行训练,用高分辨率再次微调得到,然后用来训练学生网络。

3.2.5 Classification with our approach: joint classifier

处理CLS ,和Distillation Token 预测融合的问题。

4. Experiments

4.1 Transformer Models

VIT 中的参数:
在这里插入图片描述
本文的参数:DIT-B 与VIT-B 对应,D=768, H=12, d = 64.
另外两种较小的 Dit-S, Dit-T, 都是保持d不变,改变head 数量。
在这里插入图片描述

4.2 Distillation

在这里插入图片描述
在这里插入图片描述
后边实验使用的老师都是RegNetY-16GF,直接训练,先在小分辨率预训练,然后在大分辨率微调,后边为384分辨率微调。
在这里插入图片描述
比较soft label 与 hard label,可以看出 用soft label 竟然没有改善。相反hard 却改善很多。
在这里插入图片描述
与卷积更加相关。(但问题是本来就是从卷积作为老师得到的模型啊?????

4.3 Efficiency vs Accuracy

在这里插入图片描述
速度和准确度的判别,这里速度用了images/s 这种方式来表示。

4.4 Transfer learning

在这里插入图片描述

4.5 消融实验

在这里插入图片描述
在这里插入图片描述

5. Conclusions

论文思路:用数据增强是模型在ImageNet 上得到较好的结果(中间夹带蒸馏过程),然后在fintune 一波再次提升。其实模型结构上就比价VIT 多了个蒸馏token ??? 从消融实验部分看,实验量是很多。。。。。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值