Vision Transformer Pruning简记

Vision Transformer Pruning简记

参考
剪枝流程
  • 稀疏正则训练
  • 剪枝,减去不重要的部分
  • finetune微调
剪什么?
  • 有关于稀疏训练虽然重要,但是首要还是确定剪什么,在Vision Transformer Pruning中作者剪枝的是Dimension,那么什么是Dimension呢?
    • 我的理解是:Dimension的长度其实就是FC层的输入长度
    • image-20211226153700049
怎么剪?
  • 首先回顾一下Transformer
回顾Transformer
  • 我的理解是:以Vision Transformer为例(当然作者用的好像是DeiT-base),我们首先将图片分成16x16个patch,然后进行编码embedding,会出来1xd的向量,对每个patch都这么做然后concat一下就得到了Attention部分的输入, X ∈ R n × d X \in R^{n×d} XRn×d,这里的n就是我们分出来的patch数量

  • 上文提到了输入 X ∈ R n × d X \in R^{n×d} XRn×d,然后一般我的理解是各✖️ W k 、 W q 、 W v W^k、 W^q 、W^v WkWqWv权重矩阵得到KQV,然后进入Attention计算(李宏毅老师的说法),然而作者得到KQV的方法好像是FC,这点我查证了一下VIT好像在to_kqv的时候是这么做的,这是没看VIT论文的锅(确信)

  • A t t e n t i o n ( Q K V ) = S o f t m a x ( Q K T d ) V Attention(QKV) = Softmax(\frac{QK^T}{\sqrt d})V Attention(QKV)=Softmax(d QKT)V

  • 对输出进行进行处理:Layer Norm+Residual

  • Y = X + F C o u t ( A t t e n t i o n ( F C q ( X ) , F C v ( X ) , F C v ( X ) ) Y = X + FC_{out}(Attention(FC_q(X),FC_v(X),FC_v(X)) Y=X+FCout(Attention(FCq(X),FCv(X),FCv(X))

  • 接下来就是MLP

  • Z = Y + F C 2 ( F C 1 ( Y ) ) Z = Y + FC_2(FC_1(Y)) Z=Y+FC2(FC1(Y))

那么剪哪里?
  • 作者给出了一张图,其中右侧是一个Transformer模块,会发现有多个Dimension Pruning,这里我放一张Transformer的encoder做对比就清楚了
  • 在这里插入图片描述

在这里插入图片描述

  • 我么可以知道两个跳跃链接分别对应encoder中的残差边,于是知道到在MSA(Multi Self-Attention)处有两个Dimension Pruning,MLP处有两个Dimension Pruning,那么这些分别是啥呢,这时候可以参考一下VIT的实现:VIT

    • 先看MLP部分:

    • class FeedForward(nn.Module):
          def __init__(self, dim, hidden_dim, dropout = 0.):
              super().__init__()
              self.net = nn.Sequential(
                  nn.Linear(dim, hidden_dim),
                  nn.GELU(),
                  nn.Dropout(dropout),
                  nn.Linear(hidden_dim, dim),
                  nn.Dropout(dropout)
              )
          def forward(self, x):
              return self.net(x)
      
    • 会发现刚好有两个nn.Linear,对应图中MLP部分的两个Linear(就是FC),然后我么就可以合理推测,所谓Dimension Pruning其实就是减少Linear的输入数量,即减少FC的输入参数

    • 然后回过头来看MSA部分:

    • class Attention(nn.Module):
          def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
              super().__init__()
              inner_dim = dim_head *  heads
              project_out = not (heads == 1 and dim_head == dim)
      
              self.heads = heads
              self.scale = dim_head ** -0.5
      
              self.attend = nn.Softmax(dim = -1)
              self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
      
              self.to_out = nn.Sequential(
                  nn.Linear(inner_dim, dim),
                  nn.Dropout(dropout)
              ) if project_out else nn.Identity()
      
          def forward(self, x):
              qkv = self.to_qkv(x).chunk(3, dim = -1)
              q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
      
              dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
      
              attn = self.attend(dots)
      
              out = torch.matmul(attn, v)
              out = rearrange(out, 'b h n d -> b n (h d)')
              return self.to_out(out)
      
    • 在init部分中也刚好有两个nn.Linear,这也印证了上文的猜想

那么怎么剪?
  • 从这张图其实就很好看出来

    • 对于FC的输入,作者做了一个Gate,对于那些小于阈值的节点,直接设为0,接下来是怎么得到这些分数,这是剪枝中的核心问题——如何判定某个节点的重要程度。
  • 作者定义如下:

    • 设分数 α ∗ ∈ { 0 , 1 } d \alpha^* \in \{0,1\}^d α{0,1}d
    • 一个节点是否剪枝可以表示为: X ∗ = X d i a g ( α ∗ ) X^*=Xdiag(\alpha^*) X=Xdiag(α)
  • 但是这样的离散值无法优化,所以作者将分数松弛了一下,变成了一个连续的数,这样就可以随着梯度下降进行优化:

    • α ^ ∈ R d \hat \alpha \in R^d α^Rd
    • 于是X就可以变成: X ^ = X d i a g ( α ^ ) \hat X = Xdiag(\hat \alpha) X^=Xdiag(α^)
    • 然后作者设置一个阈值: α ∗ = α ^ ≥ ζ \alpha^* = \hat \alpha \geq \zeta α=α^ζ,这样就可以实现评分然后根据评分来剪枝的效果
  • 于是剪枝公式可以表示为

  • X ∗ = P r u n e ( X ) X^* = Prune(X) X=Prune(X)

  • 结合上文说到多处剪枝,Transformer的公式可以变为

    • Q , K , V = F C q ′ ( P r u n e ( X ) ) , F C k ′ ( P r u n e ( X ) ) , F C v ′ ( P r u n e ( X ) ) Q,K,V = FC^{'}_q(Prune(X)),FC^{'}_k(Prune(X)),FC^{'}_v(Prune(X)) Q,K,V=FCq(Prune(X)),FCk(Prune(X)),FCv(Prune(X))
    • Y = X + F C o u t ′ ( P r u n e ( A t t e n t i o n ( F C q ( X ) , F C v ( X ) , F C v ( X ) ) ) Y = X + FC^{'}_{out}(Prune(Attention(FC_q(X),FC_v(X),FC_v(X))) Y=X+FCout(Prune(Attention(FCq(X),FCv(X),FCv(X)))
    • Z = Y + F C 2 ′ ( P r u n e ( F C 1 ′ ( P r u n e ( Y ) ) ) ) Z = Y + FC^{'}_2({Prune(FC^{'}_1(Prune(Y)))}) Z=Y+FC2(Prune(FC1(Prune(Y))))
  • 同时公式也侧面说明了剪枝剪哪里

实验部分
  • 这里就不细说了,总之就是效果还行
  • image-20211226161537384
作者的总结
  • 作者在文章结尾预测了MSA的M(即head的数量)也可以剪,这和我的预期是相符的,不过目前暂时没有看到这方面的工作,有看到的可以踢我一下。
思考
  • 对于这篇我认为实质上是将FC的剪枝方法用到Transformer中,不可否认效果还成,未来还能剪哪里呢?
  • 关于这点:
    • 首先head的数量确实可以
    • Patch Slimming for Efficient Vision Transformers这篇中提出剪枝Patch想法很新奇
    • 然后我觉得可以参考EfficientNet的思想(多个维度考虑模型大小,这里参考对网络深度的思考)是否可以讨论剪掉整个Endoer,或者只剪掉MSA和MLP保留其中一部分这样?
  • 3
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

椰子奶糖

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值