ECCV 2022 | 在视觉Transformer上进行递归,不增参数,计算量还少

视学算法专栏

机器之心编辑部

今天跟大家分享一篇来自CMU等机构的论文《Sliced Recursive Transformer》,该论文已被 ECCV 2022 接收。

目前 vision transformer 在不同视觉任务上如分类、检测等都展示出了强大的性能,但是其巨大的参数量和计算量阻碍了该模型进一步在实际场景中的应用。基于这个考虑,本文重点研究了如何在不增加额外参数量的前提下把模型的表达能力挖掘到极致,同时还要保证模型计算量在合理范围内,从而可以在一些存储容量小,计算能力弱的嵌入式设备上部署。

基于这个动机,Zhiqiang Shen、邢波等研究者提出了一个 SReT 模型,通过循环递归结构来强化每个 block 的特征表达能力,同时又提出使用多个局部 group self-attention 来近似 vanilla global self-attention,在显著降低计算量 FLOPs 的同时,模型没有精度的损失。

cf99d32c5d42e1f54ca03d9a3f68714f.png

  • 论文地址:https://arxiv.org/pdf/2111.05297.pdf

  • 代码和模型:https://github.com/szq0214/SReT

总结而言,本文主要有以下两个创新点:

  1. 使用类似 RNN 里面的递归结构(recursive block)来构建 ViT 主体,参数量不涨的前提下提升模型表达能力;

  2. 使用 CNN 中 group-conv 类似的 group self-attention 来降低 FLOPs 的同时保持模型的高精度;

此外,本文还有其他一些小的改动:

  1. 网络最前面使用三层连续卷积,卷积核为 3x3,结构直接使用了研究者之前 DSOD 里面的 stem 结构;

  2. Knowledge distillation 只使用了单独的 soft label,而不是 DeiT 里面 hard 形式的 label 加 one-hot ground-truth,因为研究者认为 soft label 包含的信息更多,更有利于知识蒸馏;

  3. 使用可学习的 residual connection 来提升模型表达能力;

如下图所示,本文所提出的模型在参数量(Params)和计算量(FLOPs)方面相比其他模型都有明显的优势:

287ddd0bd4f82503bb70f497ff6bfe03.png

下面我们来解读这篇文章:

1.ViT 中的递归模块

递归操作的基本组成模块如下图:

a89d4777baabdbea2f73a1ca2c1ecafd.png

该模块非常简单明了,类似于 RNN 结构,将模块当前 step 的输出作为下个 step 的输入重新输进该模块,从而增强模型特征表达能力。

研究者展示了将该设计直接应用在 DeiT 上的结果,如下所示:

68f4ca5713103337ba4d4ee9f6e1cc21.png

可以看到在加入额外一次简单递归操作之后就可以得到将近 2% 的精度提升。

当然具体到全局网络结构层面还有不同的递归构建方法,如下图:

402d9d14203b76a47090b3a1c9f992ae.png

其中 NLL 层(Non-linear Projection Layer)是用来保证每个递归模块输入输出不完全一致。论文提出使用这个模块的主要原因是发现在上述 Table 1 里面更多次数的递归操作并没有进一步提升性能,说明网络可能学到了一个比较简单的状态,而 NLL 层可以强制模型输入输出不一致从而缓解这种情况。同时,研究者从实验结果发现上图 (1) internal loop 相比 external loop 设计拥有更好的 accuracy-FLOPs 结果。

2. 分组的 Group Self-attention 模块

如下图所示,研究者提出了一种分组的 group self-attention 策略来降低模型的 FLOPs,同时保证 self-attention 的全局注意力,从而使得模型没有明显精度损失:

08d65f7286e772cd908dfee395a7d04e.png

Group Self-attention 模块具体形式如下:

9d3c12742fd0c8f00ace1531d1a514cc.png

Group self-attention 的缺点是只有局部区域会相互作用,研究者提出通过使用 Permutation 操作来近似全局 self-attention 的机制,同时通过 Inverse Permutation 来复原和保留 tokens 的次序信息,针对这个部分的消融实验如下所示:

31643857a23e8c8b0631f79c7738bd66.png

其中 P 表示加入 Permutation,I 表示加入 Inverse  Permutation,-L 表示如果 group 数为 1,就不使用 P 和 I(比如模型最后一个 stage)。根据上述表格的结果,研究者最后采用了 [8, 2][4,1][1,1] 这种分组设计。

3. 其他设计

可学习的残差结构 (LRC):

1a30da8d985da5f9e050dd5e032bf9fc.png

研究者尝试了上图三种结构,图(3)结果最佳。具体而言,研究者在每个模块里面添加了 6 个额外参数(4+2,2 个在 NLL 层),这些参数会跟模型其他参数一起学习,从而使网络拥有更强的表达能力,参数初始化都为 1,在训练过程 6 个参数的数值变化情况如下所示:

78525a624811fad875d79828becdb36a.png

Stem 结构组成:

08733a7f847fc4d2db24f91bbe3b011e.png

如上表所示,Stem 由三个 3x3 的连续卷积组成,每个卷积 stride 为 2。

整体网络结构:

研究者进一步去掉了 class token 和 distillation token,并且发现精度有少量提升。

e900b757e623da72550555ee02546a7c.png

消融实验:

683e15c8ca9180af4b6be0f4d4c2c655.png

模型混合深度训练:

研究者进一步发现分组递归设计还有一个好处就是:可以支持模型混合深度训练,这种训练方式可以大大降低深度网络结构优化复杂度,研究者展示了 108 层不同模型结构优化过程的 landscape 可视化,如下图所示,可以很明显的看到混合深度结构优化过程困难程度显著低于另外两种结构。

88b68a3b492beb16d76cca3d10d8c0a5.png

最后,分组 group self-attention 算法 PyTorch 伪代码如下:

be0adae408e2c232ee57401b34ad721d.png

更多方法和实验细节可以阅读原论文和 GitHub 代码。

© THE END 

转载请联系原公众号获得授权

4b195775ec42bd7d396df435b0df1020.png

outside_default.png

点个在看 paper不断!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值