论文阅读《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》

SwinTransformer是微软提出的一种新型Transformer架构,用于解决视觉任务中的尺度问题和高分辨率需求。它引入了层次化的窗口机制,通过shiftedwindow操作实现了线性时间复杂度,适用于多尺度特征提取。在分类、检测和分割任务中表现出色,展现出高效且准确的性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

论文链接:https://arxiv.org/abs/2103.14030

代码地址: https://github.com/microsoft/Swin-Transformer

开源15天,star超3k的作品。作者提出了一个基于transformer的backbone,可用于多种视觉任务。和以往ViT,DETR等结构不同的是,Swin Transformer通过shifted windows操作,实现了CNN里面的hierarchical的结构。这类hierarchical的结构更适用于不同的scale,同时其计算复杂性仅与image size线性相关。实验证明,以Swin Transformer为backbone的模型,在分类、检测和分割等多个任务上实现了霸榜。

Motivation

将NLP领域的Transformer迁移到CV的task上,需要考虑这两个模态之间的不同:(1)scale问题:像object detection,目标的尺度不一样,而现有的Transformer里面的token大多对应一个固定的尺度,因此并不适合这类task;(2)图像任务需要更高的resolution:以semantic segmentation为例,对现存的Transformer来说,想要得到更高分辨率的结果,意味着要承担image size的平方的时间复杂度,代价显然非常高。

在这里插入图片描述

基于以上两点考量,作者提出了Swin Transformer以解决上述问题:(1)首先该backbone的特征图构建是hierarchical的,因此可以适应多尺度。通过不断的融合临近的patch,即可得到hierarchical的特征图,类似于ResNet,后面可以接FPN或U-Net进行多尺度密集预测;(2)然后他的计算复杂性是和image size线性相关的,相比于现存方法的平方相关优化了很多。用到了non-overlapping window,即一个window包含若干个patch,然后所有的self-attention均在window内部locally计算。由于window内patch数量是固定的,所以时间复杂度与image size成线性关系。

由于上述的window是non-overlapping的,如果只对每个window独立地做self-attention,并不能有效的融合周围的特征,不适合提取到hierarchical的feature map(对比CNN里的kernal在整个feature map上进行滑动)。因此Swin Transformer中的一个核心结构shifted window,用来划分相邻self-attention层的window。

Architecture

在这里插入图片描述
Swin Transformer总体结构如上。输入是RGB图像,首先经过一个Patch Partition模块,这里以 4 ∗ 4 4*4 44大小为一个patch,所以划分后维度变成 H 4 ∗ W 4 ∗ 48 \frac{H}{4}*\frac{W}{4}*48 4H4W48。然后经过一个Linear Embedding层,可以将特征嵌入到任意维度,这里记作 C C C。随后经过一个核心的Swin Transformer Block模块,token数不变。以上属于Stage1。

为了产生一个hierarchical的特征表示,token的数量随着网络的加深应该不断减少,因此这里采用Patch Merging模块,来对 2 ∗ 2 2*2 22区域内的patch进行融合,这样每一个新patch特征维度变成 4 C 4C 4C。为了减少计算量,融合后紧跟一个降维操作,将维度降到 2 C 2C 2C。随后经过Swin Transformer Block,维度保持不变。以上为Stage2。

后续Stage3-5和之前是一样的,不断融合相邻的patch,直到Stage5输出的特征图大小为 H 32 ∗ W 32 ∗ 8 C \frac{H}{32}*\frac{W}{32}*8C 32H32W8C。这五个Stage联合起来产生了hierarchical的特征表示。和VGG,ResNet这些网络是一样的,都产生了多层级的feature map。因此Swin Transformer这个backbone相比于其他的视觉Transformer,更适合用在CV的task上面。

其中的核心模块Swin Transformer Block如Figure3(b)所示,关键之处使用W-MSA和SW-MSA取代了传统Transformer中的MSA(Multi-head self attention module),其他基本不变。MLP,LayerNorm和shortcut结构都用在了这个Block中,不过不是重点。重点是SW-MSA用到了shifted window操作,使得hierarchical feature和线性时间复杂度成为可能。

Shifted Window based Self-Attention

在这里插入图片描述
关于window与patch的关系,上图表示的很清楚。首先分析一下以前的global self-attention,和本文采用在每个window内部做self-attention的时间复杂度。两种时间复杂度如下式:
在这里插入图片描述

两个公式的由来是这样的:
(1) Q = x ∗ W Q Q=x*W^{Q} Q=xWQ,需要 h w C 2 hwC^{2} hwC2。K和V的复杂度与Q一样。随后是 Q K T QK^{T} QKT需要 ( h w ) 2 C (hw)^{2}C (hw)2C,其结果与V相乘同样需要 ( h w ) 2 C (hw)^{2}C (hw)2C。得到结果 Z Z Z以后乘 W Z W^{Z} WZ需要 h w C 2 hwC^{2} hwC2。因此总计 4 h w C 2 + 2 ( h w ) 2 C 4hwC^{2}+2(hw)^{2}C 4hwC2+2(hw)2C。这个结果显然是和image size平方相关的。
(2)如采用window内部self attention,KQVZ的复杂度都不变。 Q K T QK^{T} QKT需要计算 h M ∗ w M \frac{h}{M}*\frac{w}{M} MhMw个window,每个window的复杂度为 ( M 2 ) 2 C (M^{2})^{2}C (M2)2C,因此总计 M 2 h w C M^{2}hwC M2hwC。后续结果与V相乘的计算量也为 M 2 h w C M^{2}hwC M2hwC。所以总体的时间复杂度为 4 h w C 2 + 2 M 2 h w C 4hwC^{2}+2M^{2}hwC 4hwC2+2M2hwC。由于M是固定的,因此该结果与image size线性相关。

通过上述分析,Swin Trabsformer已经解决了现存方法时间复杂度高的缺点。但如果采用只计算window内部的self-attention,会导致不同window之间缺少connection,换句话说,global信息将会损失。为了解决这个问题,作者提出了Shifted window,用以划分相邻layer特征图上的window。

一种很直观的方法就是按照上述Figure2右侧图的划分方式,这样在之前层的那些non-overlapping的window就可以建立联系。但是Figure2的Shifted window方式也有弊端,window数目从 2 ∗ 2 2*2 22变成了 3 ∗ 3 3*3 33,计算复杂度直接上涨2.25倍,并不是特别乐观。因此作者提出了一种shifted方式,如下图:
在这里插入图片描述
如上图,依然按照Figure2的方式划分网格,但是采用了cyclic shift(左二):将浅色区域的A B C平移到深色区域的对应位置。在做了这个shift操作以后,在原来的feature map上一些不相邻的window会变成相邻的,且位于同一个新的window中,因此需要采用一个mask机制,来抑制这两个本不相邻的部分在一个window中做self attention。这一部分最好结合代码,可参考该文章

最后是在做self attention的时候,加上一个position bias,如下:
在这里插入图片描述

Experiment

分类任务惜败于EfficientNet,可能是因为多尺度和分类关系不是那么密切?
在这里插入图片描述
检测精度还是很高的。
在这里插入图片描述
语义分割最大的网络精度也很高。
在这里插入图片描述

总结

作者提出了一个Transformer backbone,不仅可以产生hierarchical的特征表示,还可以使时间复杂度降至和image size线性相关。核心部分就是window的引入降低了复杂度,patch的融合以及shifted window的引入可以提取多尺度的feature。Swin Transformer在分类、检测、分割等多个任务中达到了很高的精度,同时拥有不俗的速度。虽然不认同是对CNN的降维打击,但这种融合了CNN层级思想的Transfomer架构还是非常值得学习和follow的。

### Swin Transformer 的复现教程 #### 1. 模型概述 Swin Transformer 是一种分层视觉变换器 (Hierarchical Vision Transformer),它通过滑动窗口机制构建局部表示并支持跨窗口连接[^1]。该模型的核心组件包括分层设计、移位窗口 (Shifted Window) 和自注意力机制。 #### 2. 数据预处理 数据预处理阶段涉及将输入图像划分为多个 patch,并将其映射到 token 序列中。具体过程如下: - 将输入图像切分成大小为 \(P \times P\) 的 patches。 - 对每个 patch 进行线性嵌入操作,得到初始的 token 表示。 - 使用卷积下采样层进一步减少空间分辨率,形成多尺度特征图。 此阶段通常称为“阶段 1”,其中 transformer 块的数量为 \(H/4 \times W/4\),即每张图片被分解成若干 tokens[^2]。 #### 3. 移位窗口机制 为了提高效率和建模能力,Swin Transformer 引入了移位窗口策略。在标准窗口划分的基础上,每隔一层会调整窗口的位置以引入交叉窗口的信息交互。这种方法显著提升了性能,在 ImageNet-1K 图像分类任务上 top-1 准确率提高了 +1.1%,而在 COCO 目标检测任务中则分别提升 +2.8 box AP 和 +2.2 mask AP[^4]。 #### 4. PyTorch 实现代码 以下是基于 PyTorch 的 Swin Transformer 核心模块实现: ```python import torch from torch import nn class PatchEmbed(nn.Module): """Patch Embedding Layer""" def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x).flatten(2).transpose(1, 2) return x class Mlp(nn.Module): """Multilayer Perceptron""" def __init__(self, in_features, hidden_features=None, out_features=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_features, out_features) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.fc2(x) return x class WindowAttention(nn.Module): """Window-based Multi-head Self Attention (MSA) module with relative position bias.""" def __init__(self, dim, window_size, num_heads): super().__init__() self.dim = dim self.window_size = window_size self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=True) self.attn_drop = nn.Dropout(0.) self.proj = nn.Linear(dim, dim) def forward(self, x): B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) return x class SwinTransformerBlock(nn.Module): """Swin Transformer Block""" def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0): super().__init__() self.input_resolution = input_resolution self.window_size = window_size self.shift_size = shift_size if min(self.input_resolution) <= self.window_size: self.shift_size = 0 self.window_size = min(self.input_resolution) self.norm1 = nn.LayerNorm(dim) self.attn = WindowAttention( dim, window_size=(self.window_size, self.window_size), num_heads=num_heads ) self.norm2 = nn.LayerNorm(dim) mlp_hidden_dim = int(dim * 4) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = shifted_x.unfold(1, self.window_size, self.window_size)\ .unfold(2, self.window_size, self.window_size) x_windows = x_windows.contiguous().view(-1, self.window_size*self.window_size, C) # attention and projection attn_windows = self.attn(x_windows) attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # reverse windows shifted_x = attn_windows.permute(0, 1, 2, 3).contiguous().view(B, H, W, C) # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H*W, C) # FFN x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class BasicLayer(nn.Module): """A basic Swin Transformer layer for one stage.""" def __init__(self, dim, depth, num_heads, window_size=7): super().__init__() self.blocks = nn.ModuleList([ SwinTransformerBlock( dim=dim, input_resolution=(window_size, window_size), num_heads=num_heads, window_size=window_size, shift_size=0 if i % 2 == 0 else window_size // 2 ) for i in range(depth)]) def forward(self, x): for blk in self.blocks: x = blk(x) return x class SwinTransformer(nn.Module): """Overall architecture of the Swin Transformer model.""" def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7): super().__init__() self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))] self.layers = nn.ModuleList() for i_layer in range(len(depths)): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size ) self.layers.append(layer) self.norm = nn.LayerNorm(int(embed_dim * 2 ** (len(depths)-1))) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(int(embed_dim * 2 ** (len(depths)-1)), num_classes) if num_classes > 0 else nn.Identity() def forward(self, x): x = self.patch_embed(x) for layer in self.layers: x = layer(x) x = self.norm(x.mean(1)) x = self.head(x) return x ``` #### 5. 训练与验证流程 训练过程中可以使用常见的优化算法(如 Adam 或 SGD),并通过学习率调度器动态调整超参数。对于下游任务(如目标检测或语义分割),可以通过微调预训练权重来加速收敛。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值