大白话Pyramid Vision Transformer

本文转载自知乎,已获作者授权转载。

链接:https://zhuanlan.zhihu.com/p/353222035

TL;DR
这个工作把金字塔结构引入到Transformer[1]中,使得它可以像ResNet[2]那样无缝接入到各种下游任务中(如:物体检测,语义分割),同时也取得了非常不错的效果。

希望这些尝试能够促进更多下游任务的进一步发展,将NLP领域中Transformer的火把传递到CV的各个任务上。欢迎各位看官试用~


我们主要做了以下几个微小的工作:

1. 分析ViT[3]遗留的问题

大家都知道,在ViT中,作者为图像分类提出了一个纯Transformer的模型,迈出了非常重要的一步。相信大家看到ViT后的第一个想法就是:要是能替代掉恺明大佬的ResNet用在下游任务中岂不是美滋滋。


但是,ViT的结构是下面这样的,它和原版Transformer一样是柱状结构的 。


这就意味着:

1)它全程只能输出16-stride或者32-stride的feature map;

2)一旦输入图像的分辨率稍微大点,占用显存就会很高甚至显存溢出。

图来自ViT的论文[3]。

图来自网络,侵删

2. 引入金字塔结构

计算机视觉中CNN backbone经过多年的发展,沉淀了一些通用的设计模式。
最为典型的就是金字塔结构。

简单的概括就是:

1)feature map的分辨率随着网络加深,逐渐减小;

2)feature map的channel数随着网络加深,逐渐增大。

几乎所有的密集预测(dense prediction)算法都是围绕着特征金字塔设计的,比如SSD[4],Faster R-CNN[5], RetinaNet[6]。

这个结构怎么才能引入到Transformer里面呢?

图来自网络,侵删

试过一堆胡里花哨的做法之后,我们最终还是发现:简单地堆叠多个独立的Transformer encoder效果是最好的(奥卡姆剃刀定律,yes)。

然后我们就得到了PVT,如下图所示。在每个Stage中通过Patch Embedding来逐渐降低输入的分辨率。

其中,除了金字塔结构以外。为了可以以更小的代价处理高分辨率(4-stride或8-stride)的feature map,我们对Multi-Head Attention也做了一些调整。


为了在保证feature map分辨率和全局感受野的同时降低计算量,我们把key(K)和value(V)的长和宽分别缩小到以前的1/R_i。

通过这种方法,我们就可以以一个较小的代价处理4-stride,和8-stride的feature map了。

3. 应用到检测分割上

接下来,我们可以把PVT接入到检测和分割模型上试试水了。PVT替换ResNet非常容易,以大家常用的mmdetection[7]为例,放置好模型的代码文件之后,只需要修改模型的config文件就可以了。

从上面的图中可以看到,PVT在RetinaNet上的效果还是非常不错的。在和ResNet50相同的参数量下,PVT-S+RetinaNet在COCO val2017上的AP可以到40+。


另外我们还基于PVT+DETR[8]和Trans2Seg[9]构建了纯transformer的检测和分割网络,效果也不错。具体参考论文Section 5.4。

最后分享一下我个人的一些不成熟的看法吧。

1. 为什么PVT在同样参数量下比CNN效果好?


我认为有两点 :1)全局感受野和 2)动态权重。

其实本质上,Multi-Head Attention(MHA)和Conv有一些相通的地方。MHA可以大致看作是一个具备全局感受野的,且结果是按照attention weight加权平均的卷积。因此Transformer的特征表达能力会更强。

关于MHA和Conv之间的联系,更多有意思的见解可以拜读下代季峰大佬的An Empirical Study of Spatial Attention Mechanisms in Deep Networks[10]。


2. 后续可扩展的思路

1)效率更高的Attention:随着输入图片的增大,PVT消耗资源的增长率要比ResNet要高,所以PVT更适合处理中等输入分辨率的图片(具体见PVT的Ablation Study)。所以找到一种效率更高的Attention方案是很重要的。

2)Position Embedding:PVT的position embedding是和ViT一样,都是随机的参数,然后硬学的。而且在改变输入图像的分辨率的时候,position embedding还需要通过插值来调整大小。所以我觉得这也是可以改进的地方,找到一种更适合2D图像的方法。

3)金字塔结构:PVT只是一种较简单的金字塔式Tranformer。中间是通过Patch Embedding连接的,或许有更优美的方案。

最后的最后,PVT在分类/检测/分割或其他一些领域上的应用(code/config/model)都会在github.com/whai362/PVT上发布,大家可以star一下,方便后面查看。

论文:https://arxiv.org/abs/2102.12122

源码:https://github.com/whai362/PVT

参考资料

[1]^Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[C]//Proceedings of the 31st International Conference on Neural Information Processing Systems. 2017: 6000-6010.

[2]^He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.

[3]^Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020.

[4]^Liu W, Anguelov D, Erhan D, et al. Ssd: Single shot multibox detector[C]//European conference on computer vision. Springer, Cham, 2016: 21-37.

[5]^Ren S, He K, Girshick R, et al. Faster R-CNN: towards real-time object detection with region proposal networks[C]//Proceedings of the 28th International Conference on Neural Information Processing Systems-Volume 1. 2015: 91-99.

[6]^Lin T Y, Goyal P, Girshick R, et al. Focal loss for dense object detection[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2980-2988.

[7]^Chen K, Wang J, Pang J, et al. MMDetection: Open mmlab detection toolbox and benchmark[J]. arXiv preprint arXiv:1906.07155, 2019.

[8]^Carion N, Massa F, Synnaeve G, et al. End-to-end object detection with transformers[C]//European Conference on Computer Vision. Springer, Cham, 2020: 213-229.

[9]^Xie E, Wang W, Wang W, et al. Segmenting transparent object in the wild with transformer[J]. arXiv preprint arXiv:2101.08461, 2021.

[10]^Zhu X, Cheng D, Zhang Z, et al. An empirical study of spatial attention mechanisms in deep networks[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019: 6688-6697.

END

备注:TFM

Transformer交流群

2D、3D目标检测等最新资讯,若已为CV君其他账号好友请直接私信。

我爱计算机视觉

微信号:aicvml

QQ群:805388940

微博知乎:@我爱计算机视觉

投稿:amos@52cv.net

网站:www.52cv.net

在看,让更多人看到  

  • 6
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值