目录
一、Swin Transformer的整体架构
本文提出了一种新的视觉Transformer,称为Swin Transformer(Shifted Window Transformer)。该模型解决了视觉实体尺度巨大变化和图像像素高分辨率带来的挑战。Swin Transformer通过移位窗口计算表示,采用分层的Transformer架构,提高了效率,并能在多种尺度上进行建模。其计算复杂度相对于图像大小是线性的,适用于图像分类、目标检测和语义分割等视觉任务,显著超过了之前的最先进技术。
1. 输入图像分块
Swin Transformer首先将输入的RGB图像分割成非重叠的小块(patches),每个小块大小为4×4像素。每个小块作为一个“token”,其特征是原始像素RGB值的拼接。然后,通过一个线性嵌入层将这些特征投影到一个任意维度(记作C),例如96维(Swin-T)或128维(Swin-B)。
2. 分层表示
Swin Transformer通过逐层合并邻近的小块来构建分层表示。这一过程通过“patch merging”层实现,在每一层合并2×2的相邻小块,生成新的token,并通过线性层进行特征变换。例如,合并后的特征维度可以变为原来的2倍(如从C到2C)
3. 分层结构
Swin Transformer的分层结构由四个阶段组成,每个阶段生成不同分辨率的特征
Stage 1:输入为原始图像分块,生成初始的token特征。特征图大小为[H/4, W/4]。
Stage 2:通过合并相邻token,将特征图分辨率降低一半(2×下采样),生成[H/8, W/8]特征图。
Stage 3:继续合并token,生成[H/16, W/16]的特征图。
Stage 4:进一步合并token,生成[H/32, W/32]的特征图。
4. 移位窗口自注意力机制
每个阶段的特征变换通过Swin Transformer块实现,该块采用移位窗口自注意力机制(Shifted Window Self-Attention)。自注意力计算在非重叠的局部窗口内进行,每个窗口包含固定数量的token,从而使计算复杂度相对于图像大小是线性的。在连续的Transformer块中,窗口划分方式在每一层之间进行移位,使得跨窗口的连接得以实现,增强了模型的建模能力。
具体来说:
W-MSA(Window based Multi-head Self Attention):在固定窗口内计算自注意力。
SW-MSA(Shifted Window based Multi-head Self Attention):窗口划分方式移位后计算自注意力,形成跨窗口连接。
5. 相对位置偏置
在自注意力计算中引入相对位置偏置(Relative Position Bias),进一步提升了模型的性能。相对位置偏置考虑了token之间的相对距离,增强了模型对局部和全局信息的捕捉能力。
6. 分层设计的优势
Swin Transformer的分层设计具有的优势:
多尺度处理能力:通过逐层合并token,能够有效处理不同尺度的视觉实体。
高效计算:移位窗口机制使得自注意力计算限制在局部窗口内,保持计算效率。
灵活适应:适用于各种视觉任务,包括图像分类、目标检测和语义分割。
二、代码实现
1.MLP :多层感知机,是一种前馈神经网络
# 简单的多层感知机MLP 、 用于特征提取或其他模型的构建块
# 通过输入特征,经过隐藏层和激活函数后,生成输出
class Mlp(nn.Module): # 输入特征数 , 隐藏层特征数 , 输出特征数 , 激活函数层 , dropout 比率
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
# 如果 out_features 或 hidden_features 没有提供,则默认为 in_features
out_features = out_features or in_features
hidden_features = hidden_features or in_features
# 定义网络层 一个线性层fc1 、 激活函数层act 、 线性层fc2 、 dropout 层 drop
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
2.window_partition 函数将输入张量 x 分割成多个窗口
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
'''
H // window_size, window_size, 图像的高度被 window_size 分割后的窗口数目 (height // window_size),即每个批次中可以容纳多少个 window_size 高的窗口
W // window_size, window_size, 图像的宽度被 window_size 分割后的窗口数目 (width // window_size),即每个批次中可以容纳多少个 window_size 宽的窗口
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
permute 调整张量的维度顺序,以便窗口在最后两个维度中,即 (B, H // window_size, W // window_size, window_size, window_size, C)。contiguous 确保内存中的数据排列是连续的
windows = x.view(-1, window_size, window_size, C)
张量展平为 (num_windows*B, window_size, window_size, C),其中 num_windows 是每个批次中窗口的数量
return 最终返回的 windows 张量包含了图像的所有窗口,每个窗口的形状是 (window_size, window_size, C)
'''
B, H, W, C = x.shape
# H 和 W 被分成了两个维度,分别代表窗口的数量和每个窗口的大小,从而增加了两个维度
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) # 将原始图像划分为多个 window_size x window_size 的窗口来扩展维度
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
3.window_reverse : 将窗口形式的张量恢复为原始图像的形状
# 将窗口形式的张量恢复为原始图像的形状
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
# 通过将窗口总数 (windows.shape[0]) 除以每张图像的窗口数目 (H * W / window_size / window_size) 来计算批次大小 B
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
# -1 是一个特殊的占位符,表示让 PyTorch 自动计算该维度的大小。这里,-1 让 PyTorch 根据其他维度的大小自动推断出通道数 C,确保张量的总元素数目保持不变
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # -1 表示通道数 C 被自动推断出来
return x
4. Patch Embedding : 图像切分为补丁并进行嵌入
# 将图像切分为补丁并进行嵌入。(Patch Embedding)
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): 图像大小 Image size. Default: 224.
patch_size (int): 补丁大小 Patch token size. Default: 4.
in_chans (int): 输入通道数 Number of input image channels. Default: 3.
embed_dim (int): 嵌入维度 Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): 归一化层 Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
# 将 img_size 和 patch_size 转换为二元组。如果输入为单一整数,则会转化为 (value, value) 形式
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
# 计算补丁的分辨率。每个维度上的补丁数量
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size # 图像大小
self.patch_size = patch_size # 补丁大小
self.patches_resolution = patches_resolution # 不定分辨率
self.num_patches = patches_resolution[0] * patches_resolution[1] # 补丁总数
self.in_chans = in_chans # 输入通道数
self.embed_dim = embed_dim # 嵌入维度
# 定义卷积层,将每个补丁的像素点投影到嵌入维度空间中 卷积的 kernel_size 和 stride 都设置为 patch_size,即对每个补丁进行处理
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# 如果提供了归一化层,则创建该层并将其应用于嵌入维度,否则不使用归一化层
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# 通过卷积层处理图像,将其形状从 (B, C, H, W) 转换为 (B, num_patches, embed_dim)
# flatten(2) 将最后两个维度展平成补丁的数量,transpose(1, 2) 调整维度顺序以便与 Transformer 等模型兼容
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
# 计算卷积操作的 FLOPs。每个补丁的计算量为 embed_dim * in_chans * (patch_size[0] * patch_size[1]),然后乘以补丁的数量 Ho * Wo。
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
# 如果有归一化层,计算其 FLOPs。每个补丁需要 embed_dim 次操作,然后乘以补丁的数量
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
5.PatchMerging : 图像合并
# 图像块(patches)合并
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution # 输入特征图的分辨率(高和宽)
self.dim = dim # 输入特征图的通道数。
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) # 线性变换层,用于减少通道数(从 4 * dim 到 2 * dim)
self.norm = norm_layer(4 * dim) # 归一化层,用于规范化特征图。
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape # H*W 是特征图的总像素数
# 验证输入特征图的尺寸是否符合预期
# 确保输入特征图的展平形式的总元素数量与原始的高(H)和宽(W)的乘积一致。
assert L == H * W, "input feature has wrong size" # 如果不等,则说明输入特征图的尺寸与预期不符,会抛出错误信息 "input feature has wrong size"。
# 确保输入特征图的高和宽都是偶数。
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
# 从每个 H x W 的特征图中提取四个不同的子块。每个子块的大小是 [B, H/2, W/2, C]
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
# 将四个子块沿通道维度拼接,形成新的特征图
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x) # 通过线性层(self.reduction)减少通道数,从 4*C 到 2*C
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim # 假设有 H * W 个位置,每个位置涉及 self.dim 次浮点运算。这个值通常用于计算在图像上进行某些操作(如卷积)时的 FLOPs
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
"""
(H // 2) * (W // 2):表示下采样后的图像分辨率(假设下采样因子为2),即每个维度缩小一半后的新高度和宽度。
4:可能表示有4个操作,或者某种操作的系数。
self.dim * 2 * self.dim:表示每个位置涉及的浮点运算量。
这里的 2 * self.dim 可能表示某种计算,例如两个矩阵乘法操作(self.dim 为矩阵的维度),每个操作有 self.dim 次运算
"""
6.BasicLayer : 处理输入数据的处理和转换
"""
BasicLayer 类用于构建 Transformer 模型中的基本层,特别是在类似 Swin Transformer 的架构中。
它定义了多个 Transformer 块的堆叠,并可选择性地包括下采样层。它处理输入数据的处理和转换,
支持设置各种超参数,如注意力头数、窗口大小和 dropout 比例,以实现有效的特征提取和变换。
"""
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): 输入特征的维度->通道数 Number of input channels.
input_resolution (tuple[int]): 输入图像的分辨率 Input resolution.
depth (int): Transformer 层的深度,即包含多少个 SwinTransformerBlock Number of blocks.
num_heads (int): 自注意力机制中的头数 Number of attention heads.
window_size (int): 窗口的大小,用于局部自注意力 Local window size.
mlp_ratio (float): MLP 部分的比率,控制前馈网络的隐藏层维度 Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): 是否在自注意力的 Q、K、V 矩阵中使用偏置 If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): 缩放因子,用于自注意力中的 Q 和 K。 Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout 概率,用于前馈网络 Dropout rate. Default: 0.0
attn_drop (float, optional): Dropout 概率,用于自注意力机制 Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Drop path 概率,用于模型训练时的正则化 Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: 归一化层的类型 nn.LayerNorm
downsample (nn.Module | None, optional): 一个可选的下采样模块,用于改变特征图的分辨率 Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. 是否使用检查点技术以节省内存 Default: False.
fused_window_process (bool, optional): 是否使用融合的窗口处理方法,以优化计算效率 If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
"""
blocks: 由多个 SwinTransformerBlock 组成的 nn.ModuleList,每个块的 shift_size 取决于其在深度中的位置。
downsample: 如果 downsample 不为空,则构造一个下采样模块,用于调整特征图的分辨率。
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
fused_window_process=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks 创建了一个包含多个 SwinTransformerBlock 的 nn.ModuleList
# 通过列表推导式生成每一层 SwinTransformerBlock,并根据其在网络中的位置调整 shift_size 和 drop_path。
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
fused_window_process=fused_window_process)
for i in range(depth)])
# patch merging layer
# 如果提供了 downsample 函数,则创建下采样层 self.downsample。否则,将 self.downsample 设置为 None。
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
# 对输入 x 应用所有块。如果 use_checkpoint 为 True,则使用检查点来节省内存。否则,直接应用块
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
# 如果存在下采样层,则对结果应用下采样。返回最终的输出 x。
if self.downsample is not None:
x = self.downsample(x)
return x
# 返回类的额外信息,帮助了解模型的维度、输入分辨率和深度。
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
# FLOPs 计算方法
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
7.WindowAttention
8.SwinTransformerBlock
9.SwinTransformer