一篇搞懂swin-transformer:Hierarchical Vision Transformer using Shifted Windows

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows:Swin Transformer:使用移位窗口的分层视觉转换器。作者大都是华人,论文由亚洲微软团队,中国科学技术大学,西安交通大学,清华大学写出。刊会是计算机视觉方面的三大顶级会议:ICCV,CVPR,ECCV,被评为2021年最佳论文的荣誉

备注:可以用于组会上给别人讲ppt😄

一、背景与意义

Transformer是为序列建模和转导任务而设计的,它以关注数据中的远程依赖关系建模而闻名。它在语言领域的巨大成功促使研究人员研究它对计算机视觉的适应性,特别是图像分类和联合视觉语言建模。在本文中,试图扩大Transformer的适用性,使其可以作为计算机视觉的通用主干

*远程序列关系就是说:第一个得到的结果会受到离它很远的元素影响

但是把语言领域的高性能转移到视觉领域有两个重大挑战。其中一个差异为规模:与语言转换器中用作处理基本元素的单词标记不同,视觉元素的规模可能会有很大变化,在现有的基于Transformer的模型中,token都是固定规模的,这一特性不适合这些视觉应用。另一个区别是,与文本段落中的单词相比,图像中的像素分辨率高得多。

为了克服这些问题,我们提出了一种通用的Transformer主干,称为Swin Transformer,它构建分层特征图,并对图像大小具有线性计算复杂性。线性计算复杂性是通过在划分图像(红色轮廓)的非重叠窗口内局部计算自注意来实现的。每个窗口中的补丁数量是固定的,因此复杂性与图像大小成线性关系。本质思想就是注意力机制结合CNN的分层。

二、网络框架

通过网络框架能看到,图片大小在减少,但是图片的深度在上升,接下来为每层进行讲解:

1、Patch partition和Linear Embedding--理解为对图片进行大小通道的处理

Patch partition:可以用一个4*4、步长为4的卷积核来做卷积,通道数48

Linear Embedding:用1*1卷积核,步长为1的卷积核来做卷积,通道数96

2、W-MSA

W-MSA是在MSA(多头自注意力机制)的基础上进行改进。先简单介绍一下MSA,主要就是对下面图片公式的一个理解。

W都是经过学习得到的,输出b的向量长度和输入的a长度是一样的,Q和K相乘得到一个结果是矩阵,矩阵做softmax激活函数。除以根号d是为了什么:根号d表示的是向量k的长度,会使得分布重新归一化到数据原分布,防止出现梯度消失,除以之后到达的位置可能在(0-1)之间,Softmax使用在多分类的情况下的结果在0-1之间。

接下来就是W-MSA层做的事情:

原先的MSA是每个patch都与和其他的patch有关,因为计算QKV的时候都有相互联系。而W-MSA是先把特征矩阵分成很多windows(窗口),论文里面使用的7*7的窗口大小来分,然后每个window内做MSA。

优:目的是减少计算量; 缺:各个window之间不能通信,且导致感受野变小

使得计算量减少:

对于MSA:首先计算Q矩阵对应的计算量为HW×C ×C再加上还有KV,故计算量为3 HW×C ×C ,然后就算QK相乘的计算量为 HW× HW×C ,假设忽略SoftMax等计算得到的结果设为Z,然后Z和V相乘的计算量为HW× HW× C ,最后因为多头再进行融合计算量为HW× HW× C 。

对于W-MSA:我们在W-MSA中把特征图划分了各个window,窗口的宽高为M,所以原图有h/M× w/M窗口。所以先把M带入到MSA中的宽高,得到的结果再乘以h/M× w/M。

3、SW-MSA

为了解决上面W-MSA的缺点,所以才有了这一步:首先从layer1变为layer1+1,变完之后各个window中大小不为4*4那么如何做注意力机制,如果填充就会加大计算。论文做法就是把所有块进行移动变成4*4。

移动完之后它们原先的相对位置改变了,并且本身不是同一个window,论文是对于移动完生成4*4的块照常做attention,但是对于不是同一个区域的QKV得到的相关性设置为0,经过以上所有计算,我们要把移动完的块再次移动到初始的位置。

4、W-MSA和SW-MSA公式

W-MSA比MSA在公式上多加上了矩阵B,B的意思是相对位置偏置,其本质就是希望attention map进一步有所偏重。因为attention map中某个值越低,经过softmax之后,该值会更低,对最终特征的贡献就低。B的计算是通过深度网络学习学到的,但是前面的qk计算得到是一个矩阵,矩阵中不同的位置要对应不同的参数B,而B学习到的参数保存在一个一维数组中,所以要实现一维数组和二维数组之间的对应关系,下图案例就是实现B的一维数组和二维数组对应关系。假设qk计算的矩阵大小为2*2.先得到矩阵所有像素的相对位置索引,而不是相对位置偏置参数,我们要根据相对位置索引找到相对位置偏置参数,而相对位置偏置参数是训练出来的,保存在relative position bias table表里的,如果表是二维,则可以直接一对一找到相应的参数,如果表是一维的,则表的长度为2M-1 ×2M-1,两种都能实现不过论文采用一维。

经过这些操作就能实现一维数组和二维数组的一一对应。

4、MLP

5、Patch Merging

通过下采样,那么矩阵大小变小,但是深度变深,当进行下一次的划分窗口的时候窗口的数量会变少。

三、实验结果

四、总结

 Swin Transformer是一种新的视觉Transformer ,它产生层次特征表示,并且对输入图像大小具有线性计算复杂度。Swin Transformer在COCO对象检测和ADE20K语义分割方面实现了最先进的性能,大大超过了以前的最佳方法。我们希望Swin Transformer在各种视觉问题上的强大表现将鼓励视觉和语言信号的统一建模。

### Swin Transformer 的复现教程 #### 1. 模型概述 Swin Transformer 是一种分层视觉变换器 (Hierarchical Vision Transformer),它通过滑动窗口机制构建局部表示并支持跨窗口连接[^1]。该模型的核心组件包括分层设计、移位窗口 (Shifted Window) 和自注意力机制。 #### 2. 数据预处理 数据预处理阶段涉及将输入图像划分为多个 patch,并将其映射到 token 序列中。具体过程如下: - 将输入图像切分成大小为 \(P \times P\) 的 patches。 - 对每个 patch 进行线性嵌入操作,得到初始的 token 表示。 - 使用卷积下采样层进一步减少空间分辨率,形成多尺度特征图。 此阶段通常称为“阶段 1”,其中 transformer 块的数量为 \(H/4 \times W/4\),即每张图片被分解成若干 tokens[^2]。 #### 3. 移位窗口机制 为了提高效率和建模能力,Swin Transformer 引入了移位窗口策略。在标准窗口划分的基础上,每隔一层会调整窗口的位置以引入交叉窗口的信息交互。这种方法显著提升了性能,在 ImageNet-1K 图像分类任务上 top-1 准确率提高了 +1.1%,而在 COCO 目标检测任务中则分别提升 +2.8 box AP 和 +2.2 mask AP[^4]。 #### 4. PyTorch 实现代码 以下是基于 PyTorch 的 Swin Transformer 核心模块实现: ```python import torch from torch import nn class PatchEmbed(nn.Module): """Patch Embedding Layer""" def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x).flatten(2).transpose(1, 2) return x class Mlp(nn.Module): """Multilayer Perceptron""" def __init__(self, in_features, hidden_features=None, out_features=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_features, out_features) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.fc2(x) return x class WindowAttention(nn.Module): """Window-based Multi-head Self Attention (MSA) module with relative position bias.""" def __init__(self, dim, window_size, num_heads): super().__init__() self.dim = dim self.window_size = window_size self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=True) self.attn_drop = nn.Dropout(0.) self.proj = nn.Linear(dim, dim) def forward(self, x): B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) return x class SwinTransformerBlock(nn.Module): """Swin Transformer Block""" def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0): super().__init__() self.input_resolution = input_resolution self.window_size = window_size self.shift_size = shift_size if min(self.input_resolution) <= self.window_size: self.shift_size = 0 self.window_size = min(self.input_resolution) self.norm1 = nn.LayerNorm(dim) self.attn = WindowAttention( dim, window_size=(self.window_size, self.window_size), num_heads=num_heads ) self.norm2 = nn.LayerNorm(dim) mlp_hidden_dim = int(dim * 4) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = shifted_x.unfold(1, self.window_size, self.window_size)\ .unfold(2, self.window_size, self.window_size) x_windows = x_windows.contiguous().view(-1, self.window_size*self.window_size, C) # attention and projection attn_windows = self.attn(x_windows) attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # reverse windows shifted_x = attn_windows.permute(0, 1, 2, 3).contiguous().view(B, H, W, C) # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H*W, C) # FFN x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class BasicLayer(nn.Module): """A basic Swin Transformer layer for one stage.""" def __init__(self, dim, depth, num_heads, window_size=7): super().__init__() self.blocks = nn.ModuleList([ SwinTransformerBlock( dim=dim, input_resolution=(window_size, window_size), num_heads=num_heads, window_size=window_size, shift_size=0 if i % 2 == 0 else window_size // 2 ) for i in range(depth)]) def forward(self, x): for blk in self.blocks: x = blk(x) return x class SwinTransformer(nn.Module): """Overall architecture of the Swin Transformer model.""" def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7): super().__init__() self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))] self.layers = nn.ModuleList() for i_layer in range(len(depths)): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size ) self.layers.append(layer) self.norm = nn.LayerNorm(int(embed_dim * 2 ** (len(depths)-1))) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(int(embed_dim * 2 ** (len(depths)-1)), num_classes) if num_classes > 0 else nn.Identity() def forward(self, x): x = self.patch_embed(x) for layer in self.layers: x = layer(x) x = self.norm(x.mean(1)) x = self.head(x) return x ``` #### 5. 训练与验证流程 训练过程中可以使用常见的优化算法(如 Adam 或 SGD),并通过学习率调度器动态调整超参数。对于下游任务(如目标检测或语义分割),可以通过微调预训练权重来加速收敛。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值