Patch Slimming for Efficient Vision Transformers简记
文章目录
参考
剪枝流程
- 稀疏正则训练
- 剪枝,剪去不重要的部分
- finetune微调
主要思想
- 作者认为,attention map中的元素反应的是两个patch之间的相似度(最大的还是与自己的相似度),此外并不是所有patch都能提供充足的信息,比如下面这张,其实只需要很少的几个特定的patch就可以判断是不是狗,而其他的patch其实信息量很少
- 因此作者提出对patch进行剪枝,大致的流程就是从最后一层开始反推减去的patch
理论支持
-
为避免符号混乱,这里贴一贴作者提到的Vision Transformer公式
- H is the number of heads,h表示多头的index
- d is embedding dimension
- KQV的得到方式就是✖️对应的W权重(也有的通过FC得到的,我突然想通了,FC其实就可以表是为✖️对应的W权重)
- Z l ′ = M S A ( Z l − 1 ′ ) Z^{'}_l = MSA(Z^{'}_{l-1}) Zl′=MSA(Zl−1′)
-
于是作者将整个过程定义为一个block B
-
剪枝后公式表示为:
-
可以发现这里多了个 m l m_l ml,他用来表示某个patch是否被剪掉,对应的剪枝后的Block B表示为:
Top-Down Pruning
-
作者在这里分析了为什么可以这么剪,我认为主要就三句话+两个问题
-
为什么Transformer不能想CNN一样剪?
-
The main reason is that patches in different layers of a vision transformer are one-to-one corresponding However, in the vision transformer (Figure 2(b)), different patches communicate with others by an attention map, which reflects the similarity between different patches.
-
如Figure 2所示,作者认为最主要的原因是ViT中的patch是一一对应的(反应不同patch之间的联系、相似度),而CNN中的每层layer则不是这样的,他是关注与某种特征,这就意味着CNN剪掉某一层的channel的意义和ViT剪掉patch是不同的
-
其实我认为,CNN剪掉channel应当对应ViT剪掉某个head,作者这种方法相当于CNN在Hx W维度上剪枝,更加细一点。
-
-
为什么要Top-Down Pruning?
-
-
上面这张图表示随着layer层数的增加,不同patch之间的相似度是上升的,这就意味着越深的layer存在越多的冗余,因此Top-Down Pruning更有机会能剪掉更多patch(越低剪掉的patch越少这样),对应作者的这句
-
It implies that more redundant patches can be safely removed in deeper layers, and fewer in shallower layers
-
-
Impact Estimation
- 这段偏理论,哪里讲错了记得踢我一下
- 在上文公式2中有一个参数 m l m_l ml,他用来表示某个patch是否被剪掉,每一层都有。那么这个问题就可以转化成一个优化问题
-
然而0范数的优化问题不是那么容易的( ∣ ∣ . ∣ ∣ 0 ||.||_0 ∣∣.∣∣0是count非0元素个数),作者的原文是说非凸、NP-Hard,要求Combinatorial Search,具体可以看看为什么不用L0范数做正则化?
-
Eq. 3 is hard to optimize directly, as it involve l0 optimization under constraint, which is non-convex, NP hard and requires combinatorial search [21]
-
-
于是作者想着用近似去替代这个优化问题,具体用到的是普希茨连续条件(Lipschitz continuity):
-
为简化起见,作者定义 P l h = s o f t m a x ( Q l K l l T d ) P^h_l = softmax(\frac{Q_l{K^l_l}^T}{\sqrt d}) Plh=softmax(dQlKllT),于是MSA可以简化为
-
这里将 W l o W ^o_l Wlo分成了H份, W l h o ∈ R d H × d W^{ho}_l \in R^{\frac{d}{H}×d} Wlho∈RHd×d,然后吧V拆成未经过FC之前的样子
-
然后有定义了 O ( . , W l ) O(.,W_l) O(.,Wl)为MSA中 W l W_l Wl矩阵乘积以及MLP中的模块,于是对应的Block B就可以表示为
-
于是对于第L层的feature定义为:
-
表示这一层的输入,用剪掉的Block计算完之后得到的结果,这样的话就可以得到某一层剪掉某个patch对网络输出的影响,也就是上文优化问题中误差的表示
-
然后就是将 F ^ \hat F F^拆开(拆一个B出来,然后B拆乘上文定义的O的那个公式),利用Lipschitz continuity做近似(吧O去掉)
-
接下来就是用Lipschitz一直拆一直近似,eq8表示一致拆到第t层,eq9是再拆一层,但是保留这一层的MSA部分,并拆开, I N I_N IN作者没说具体什么意思,根据公式有理由猜测是一个全1的矩阵(对应 d i a g ( m l ) diag(m_l) diag(ml))
-
然后对于eq10是做了一下公式简化,具体流程可以看这个笔记
-
随后作者定义 m l , i m_{l,i} ml,i为 m l m_l ml的第i个元素,eq10就可以进一步细分,然后经过放缩把 m l , i m_{l,i} ml,i!那块提到F范数外边
-
各元素的意思如下:
-
有了eq12,就可以对其做一个近似,把范数外边的常数去掉,即得到eq13,于是最终得到了某个patch的影响的评价分数
-
剪枝过程
最终效果
-
这是在ImageNet上的结果
-
Fig4是剪枝程度与精度的关系,Fig5是与普通剪枝的各层的Patch数量的对比
思考
- 作者的思想其实比较简单,但是推理过程比较繁琐,我也不太确定有没有理解错,本文其实有几个地方没有写清楚,比如某些参数的含义需要揣摩一下才知道啥意思,Exam中的有个uniform pruning,作者没有特别指出是哪种。
- 最后我还有个小疑问,就是如果最后m矩阵固定了,也就是某一个patch就永远没有用了,那么按照作者的说法,patch所对应的位置是固定的,比如下面这张,我在l+1层只保留这两个patch,那么万一物体的主要特征不在这两个patch呢,是不是效果就会差,或者是说这是针对某个特定数据集的剪枝?