### A Tutorial on Reproducing the Swin Transformer
#### 1. Model Overview
The Swin Transformer is a hierarchical vision transformer that builds local representations with a shifted-window mechanism and supports cross-window connections[^1]. Its core components are the hierarchical design, the shifted window (Shifted Window) scheme, and window-based self-attention.
#### 2. Data Preprocessing
Preprocessing splits the input image into patches and maps them to a token sequence (a quick shape check follows this list):
- Split the input image into non-overlapping patches of size \(P \times P\).
- Apply a linear embedding to each patch to obtain the initial token representations.
- Between stages, patch merging (downsampling) layers further reduce the spatial resolution, producing multi-scale feature maps.
The first part of the network is commonly called "Stage 1"; with a patch size of 4, the number of tokens is \(H/4 \times W/4\), i.e., each image is decomposed into that many tokens[^2].
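As a sanity check, here is a minimal sketch of the patch embedding step, consistent with the `PatchEmbed` module in Section 4 (the 224×224 input size and `embed_dim=96` are the defaults assumed throughout this tutorial):
```python
import torch
from torch import nn

# Patch embedding as a strided convolution: each non-overlapping 4x4 patch
# is projected to a single 96-dimensional token.
patch_size, embed_dim = 4, 96
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

img = torch.randn(1, 3, 224, 224)              # one RGB image
tokens = proj(img).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
print(tokens.shape)                            # torch.Size([1, 3136, 96]); 3136 = (224/4)**2
```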
#### 3. The Shifted Window Mechanism
To improve efficiency and modeling power, the Swin Transformer introduces a shifted window strategy: on top of the standard window partition, every other layer shifts the window positions to enable cross-window information exchange. This design improves performance significantly: +1.1% top-1 accuracy on ImageNet-1K image classification, and +2.8 box AP and +2.2 mask AP on COCO object detection[^4]. The shift itself is a cheap cyclic roll of the feature map, illustrated below.
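The toy example below (with arbitrary sizes chosen for illustration) shows that the cyclic shift is just a `torch.roll` by half the window size, and that it is exactly reversible:
```python
import torch

H = W = 8
window_size = 4
shift_size = window_size // 2             # shift by half the window size
x = torch.arange(H * W).view(1, H, W, 1)  # a fake one-channel feature map

# cyclic shift: rows/columns wrap around, so the next window partition
# groups tokens that previously sat in different windows
shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# the reverse shift restores the original layout exactly
restored = torch.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(x, restored)
```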
#### 4. PyTorch Implementation
Below is a simplified PyTorch sketch of the core Swin Transformer modules. For brevity it omits the relative position bias, stochastic depth, and the shifted-window attention mask used in the official implementation:
```python
import torch
from torch import nn
class PatchEmbed(nn.Module):
"""Patch Embedding Layer"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, num_patches, embed_dim)
return x
class Mlp(nn.Module):
"""Multilayer Perceptron"""
def __init__(self, in_features, hidden_features=None, out_features=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
class WindowAttention(nn.Module):
"""Window-based Multi-head Self Attention (MSA) module with relative position bias."""
def __init__(self, dim, window_size, num_heads):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.attn_drop = nn.Dropout(0.)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
B_, N, C = x.shape
        # (B_, N, C) -> (3, B_, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # scaled dot-product attention
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
super().__init__()
self.input_resolution = input_resolution
self.window_size = window_size
self.shift_size = shift_size
if min(self.input_resolution) <= self.window_size:
self.shift_size = 0
self.window_size = min(self.input_resolution)
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(
dim,
window_size=(self.window_size, self.window_size),
num_heads=num_heads
)
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * 4)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
        # partition windows: (B, H, W, C) -> (num_windows*B, window_size*window_size, C)
        x_windows = shifted_x.view(B, H // self.window_size, self.window_size,
                                   W // self.window_size, self.window_size, C)
        x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(
            -1, self.window_size * self.window_size, C)
        # W-MSA/SW-MSA (this sketch omits the attention mask the official
        # implementation applies to shifted windows)
        attn_windows = self.attn(x_windows)
        # reverse windows: (num_windows*B, window_size*window_size, C) -> (B, H, W, C)
        attn_windows = attn_windows.view(B, H // self.window_size, W // self.window_size,
                                         self.window_size, self.window_size, C)
        shifted_x = attn_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H*W, C)
# FFN
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchMerging(nn.Module):
    """Patch Merging Layer: halves the spatial resolution and doubles the channel dim."""
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        # concatenate each 2x2 neighborhood of tokens into one 4C-dim token
        x = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2],
                       x[:, 0::2, 1::2], x[:, 1::2, 1::2]], -1)
        x = x.view(B, -1, 4 * C)
        return self.reduction(self.norm(x))
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage."""
    def __init__(self, dim, input_resolution, depth, num_heads, window_size=7, downsample=None):
        super().__init__()
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                # alternate between regular (W-MSA) and shifted (SW-MSA) blocks
                shift_size=0 if i % 2 == 0 else window_size // 2
            ) for i in range(depth)])
        self.downsample = downsample(input_resolution, dim) if downsample is not None else None
    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class SwinTransformer(nn.Module):
"""Overall architecture of the Swin Transformer model."""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7):
super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        patches_resolution = (img_size // patch_size, img_size // patch_size)
        self.layers = nn.ModuleList()
        for i_layer in range(len(depths)):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                  patches_resolution[1] // (2 ** i_layer)),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                # downsample (patch merging) after every stage except the last
                downsample=PatchMerging if i_layer < len(depths) - 1 else None
            )
            self.layers.append(layer)
        num_features = int(embed_dim * 2 ** (len(depths) - 1))
        self.norm = nn.LayerNorm(num_features)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
    def forward(self, x):
        x = self.patch_embed(x)       # (B, num_patches, embed_dim)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x.mean(dim=1))  # global average pooling over tokens
        x = self.head(x)
        return x
```
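A quick smoke test of the sketch, using a Swin-T-like configuration (the batch size is arbitrary):
```python
model = SwinTransformer(img_size=224, patch_size=4, embed_dim=96,
                        depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
x = torch.randn(2, 3, 224, 224)  # a dummy batch of two images
logits = model(x)
print(logits.shape)              # torch.Size([2, 1000])
```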
#### 5. Training and Validation
Training can use a common optimizer (such as Adam or SGD) together with a learning-rate scheduler that adjusts the learning rate over the course of training; a minimal loop is sketched below. For downstream tasks (such as object detection or semantic segmentation), fine-tuning pretrained weights speeds up convergence.
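A minimal training-loop sketch on synthetic data, assuming the `SwinTransformer` class from Section 4; the AdamW hyperparameters, epoch count, and dummy batch below are illustrative placeholders, not tuned values:
```python
import torch
from torch import nn, optim

model = SwinTransformer()
criterion = nn.CrossEntropyLoss()
# AdamW + cosine decay, a common choice for training vision transformers
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# dummy data standing in for a real DataLoader of (image, label) batches
train_loader = [(torch.randn(4, 3, 224, 224), torch.randint(0, 1000, (4,)))]

for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # step the LR schedule once per epoch
```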
---