Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
(2021.8.17)
Abstract:分类:ImageNet-1K上top-1 acc达87.3;检测&分割:COCO 上达58.7 box AP和51.1 mask AP;语义分割:ADE20K val上达53.5 mIoU)
一、Introduction
如下图所示,Swin Transformer利用如FPN1或U-Net2等用于密集预测的技术,逐步在深层中合并图像块来构建分层特征图,并且由于在数量固定的每个局部窗口(红色)中计算自注意力,使得其与输入图像大小计算复杂度呈线性。相比之下ViT生成单个低分辨率的特征图,并且由于全局自注意力计算,其计算复杂度相对输入图像大小呈二次方。
Swin-Transformer的一个关键设计元素是它在连续的自注意层之间移动窗口分区,显著增强建模能力(见论文表4)。如下图所示,在layerl l中,规则划分窗口,并在每个窗口内计算自注意。在layerl l+1中,窗口移动产生新的窗口,建立跨窗口连接。一个窗口内的query patches共享key set(query&key),有助于硬件的内存访问。而早期的基于滑动窗口的自注意方法3非共享,Swin延迟更低,建模能力相似(见论文表5和表6)。移动窗口方法被证明对所有MLP架构都是有益的4。
二、Method
1.总体架构
下图,(a).Swin Transformer (Swin-T)的tiny设计结构;(b). 常规窗口和shifted windowing(移动窗口)的多头自注意力模块。
Swin-T同ViT,输入为分割patch(图块),每个图块都被视为一个“token”,其特征被置为原始像素RGB值的串联,即特征维度为4×4×3=48。线性embedding层将其映射到任意维度(图中
C
C
C)。
Swin Transformer block由一个基于移位窗口的MSA模块组成,后是一个中间有GELU的2层MLP。在每个MSA模块和每个MLP之前应用LayerNorm(LN)层,并在每个模块之后应用残差连接。
2.位置窗口的自注意力
(1).Efficient batch computation for shifted configuration(针对移位配置的高效批量计算)
移位窗口分区产生更多的窗口,由左上方进行循环位移,重新划分窗口与常规一致,如下图所示。但一个批处理窗口可能由几个不在特征图附近的子窗口组成,因此采用了mask机制来将自注意计算限制在每个子窗口内。
注:为了使窗口大小(M,M)可被特征图大小(h,w)整除,如果需要,在特征图上进行右下填充。
(2).Relative position bias(相对位置偏差)
在计算自我注意时,遵循5,包括相对位置偏差。实验(论文表4)发现不使用偏差项或使用绝对位置embedding时精度降低。如6所述,进一步将绝对位置embedding添加到输入中不利于模型表现。预训练中学习的相对位置偏差也可以用于初始化模型,以便用双立方插值算法对不同的窗口大小进行微调7。
(3).体系结构变体
Swin-B为基本模型,其模型大小和计算复杂性类似ViT-B/DeiT-B。Swin-T、Swin-S和Swin-L,模型大小和计算复杂度分别约为0.25、0.5和2。Swin-T和Swin-S的复杂性分别与ResNet-50(DeiT-S)和ResNet-101的复杂性相似。默认窗口大小为M=7。模型变体的参数:
实验中,每个head的query维度d=32,每个MLP的扩展层α=4。其中C是第一阶段中隐藏层的通道数。ImageNet图像分类的模型大小、理论计算复杂度(FLOPs)和模型变体的吞吐量见论文表1。
三、Experiments
1.Image Classification on ImageNet-1K
Settings:
-
常规ImageNet-1K训练,主要遵循8。epochs(300),batch size(1024),learning rate(0.001),weight decay(0.05),optimizer(AdamW9),scheduler(cosine decay),warm up(20;linear)。训练中包含其8中的大多数增强和正则化策略,除了不会提高性能的repeated augmentation(重复增强)10和EMA11。但重复增强对稳定ViT的训练至关重要。
-
ImageNet-22K上的预训练和ImageNet-1K上的微调。预训练:epochs(90),batch size(4096),learning rate(0.001),weight decay(0.01),optimizer(AdamW),warm up(5;linear)。微调:epochs(30),batch size(1024),learning rate(0.001),learning rate( 1 0 − 5 10^{-5} 10−5;constant),weight decay( 1 0 − 8 10^{-8} 10−8)。
2.Object Detection on COCO
四个典型的目标检测框架:Cascade Mask R-CNN12,ATSS13,RepPoints v214,mmdetection中的Sparse RCNN15。
Settings:
- epochs(3x schedule;36),batch size(16),learning rate(0.0001),weight decay(0.05),optimizer(AdamW),多尺度训练16(输入resize,短边在480到800之间,长边最长为1333)。对于系统级比较,采用一种改进的HTC17(HTC++),带有instaboost18、更强的多尺度训练19、6x schedule(72 epochs)、soft-NMS20和ImageNet-22K pre-trained作为初始化模型。
3.Semantic Segmentation on ADE20K
ps:下为论文内文献索引,供拓展使用。