Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 

目录

 

一、Swin Transformer的整体架构

1. 输入图像分块

 2. 分层表示

 3. 分层结构

4. 移位窗口自注意力机制

5. 相对位置偏置

 6. 分层设计的优势

二、代码实现


一、Swin Transformer的整体架构

    本文提出了一种新的视觉Transformer,称为Swin TransformerShifted Window Transformer)。该模型解决了视觉实体尺度巨大变化图像像素高分辨率带来的挑战。Swin Transformer通过移位窗口计算表示,采用分层的Transformer架构,提高了效率,并能在多种尺度上进行建模。其计算复杂度相对于图像大小是线性的,适用于图像分类、目标检测和语义分割等视觉任务,显著超过了之前的最先进技术。 

 

1. 输入图像分块

     Swin Transformer首先将输入的RGB图像分割成非重叠的小块(patches),每个小块大小为4×4像素。每个小块作为一个“token”,其特征是原始像素RGB值的拼接。然后,通过一个线性嵌入层将这些特征投影到一个任意维度(记作C),例如96维(Swin-T)或128维(Swin-B)。

 2. 分层表示

      Swin Transformer通过逐层合并邻近的小块来构建分层表示。这一过程通过“patch merging”层实现在每一层合并2×2的相邻小块,生成新的token,并通过线性层进行特征变换。例如,合并后的特征维度可以变为原来的2倍(如从C到2C)

 3. 分层结构

Swin Transformer的分层结构由四个阶段组成,每个阶段生成不同分辨率的特征

Stage 1:输入为原始图像分块,生成初始的token特征。特征图大小为[H/4, W/4]

Stage 2:通过合并相邻token,将特征图分辨率降低一半(2×下采样),生成[H/8, W/8]特征图。

Stage 3:继续合并token,生成[H/16, W/16]的特征图。

Stage 4:进一步合并token,生成[H/32, W/32]的特征图。

4. 移位窗口自注意力机制

每个阶段的特征变换通过Swin Transformer块实现,该块采用移位窗口自注意力机制(Shifted Window Self-Attention)。自注意力计算在非重叠的局部窗口内进行,每个窗口包含固定数量的token,从而使计算复杂度相对于图像大小是线性的。在连续的Transformer块中,窗口划分方式在每一层之间进行移位,使得跨窗口的连接得以实现,增强了模型的建模能力。

具体来说:

W-MSA(Window based Multi-head Self Attention)在固定窗口内计算自注意力。

SW-MSA(Shifted Window based Multi-head Self Attention):窗口划分方式移位后计算自注意力,形成跨窗口连接。

5. 相对位置偏置

在自注意力计算中引入相对位置偏置(Relative Position Bias),进一步提升了模型的性能。相对位置偏置考虑了token之间的相对距离,增强了模型对局部和全局信息的捕捉能力。

 6. 分层设计的优势

Swin Transformer的分层设计具有的优势:

多尺度处理能力:通过逐层合并token,能够有效处理不同尺度的视觉实体。

高效计算:移位窗口机制使得自注意力计算限制在局部窗口内,保持计算效率。

灵活适应:适用于各种视觉任务,包括图像分类、目标检测和语义分割。

二、代码实现

 1.MLP :多层感知机,是一种前馈神经网络

 

# 简单的多层感知机MLP 、 用于特征提取或其他模型的构建块
# 通过输入特征,经过隐藏层和激活函数后,生成输出
class Mlp(nn.Module): # 输入特征数  , 隐藏层特征数  ,        输出特征数  ,        激活函数层  ,      dropout 比率
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # 如果 out_features 或 hidden_features 没有提供,则默认为 in_features
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # 定义网络层 一个线性层fc1 、 激活函数层act 、 线性层fc2 、 dropout 层 drop
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

2.window_partition 函数将输入张量 x 分割成多个窗口

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    '''
     H // window_size, window_size,  图像的高度被 window_size 分割后的窗口数目 (height // window_size),即每个批次中可以容纳多少个 window_size 高的窗口
     W // window_size, window_size,  图像的宽度被 window_size 分割后的窗口数目 (width // window_size),即每个批次中可以容纳多少个 window_size 宽的窗口
    
     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
     permute 调整张量的维度顺序,以便窗口在最后两个维度中,即 (B, H // window_size, W // window_size, window_size, window_size, C)。contiguous 确保内存中的数据排列是连续的
     windows = x.view(-1, window_size, window_size, C)
     张量展平为 (num_windows*B, window_size, window_size, C),其中 num_windows 是每个批次中窗口的数量
     return  最终返回的 windows 张量包含了图像的所有窗口,每个窗口的形状是 (window_size, window_size, C)
    '''
    B, H, W, C = x.shape
    # H 和 W 被分成了两个维度,分别代表窗口的数量和每个窗口的大小,从而增加了两个维度
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) # 将原始图像划分为多个 window_size x window_size 的窗口来扩展维度
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

3.window_reverse : 将窗口形式的张量恢复为原始图像的形状

# 将窗口形式的张量恢复为原始图像的形状
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # 通过将窗口总数 (windows.shape[0]) 除以每张图像的窗口数目 (H * W / window_size / window_size) 来计算批次大小 B
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # -1 是一个特殊的占位符,表示让 PyTorch 自动计算该维度的大小。这里,-1 让 PyTorch 根据其他维度的大小自动推断出通道数 C,确保张量的总元素数目保持不变
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # -1 表示通道数 C 被自动推断出来
    return x

 4. Patch Embedding : 图像切分为补丁并进行嵌入

# 将图像切分为补丁并进行嵌入。(Patch Embedding)
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): 图像大小 Image size.  Default: 224.
        patch_size (int): 补丁大小 Patch token size. Default: 4.
        in_chans (int): 输入通道数 Number of input image channels. Default: 3.
        embed_dim (int): 嵌入维度 Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): 归一化层 Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        # 将 img_size 和 patch_size 转换为二元组。如果输入为单一整数,则会转化为 (value, value) 形式
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # 计算补丁的分辨率。每个维度上的补丁数量
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size # 图像大小
        self.patch_size = patch_size # 补丁大小
        self.patches_resolution = patches_resolution # 不定分辨率
        self.num_patches = patches_resolution[0] * patches_resolution[1] # 补丁总数

        self.in_chans = in_chans # 输入通道数
        self.embed_dim = embed_dim # 嵌入维度
        # 定义卷积层,将每个补丁的像素点投影到嵌入维度空间中  卷积的 kernel_size 和 stride 都设置为 patch_size,即对每个补丁进行处理
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # 如果提供了归一化层,则创建该层并将其应用于嵌入维度,否则不使用归一化层
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # 通过卷积层处理图像,将其形状从 (B, C, H, W) 转换为 (B, num_patches, embed_dim)
        # flatten(2) 将最后两个维度展平成补丁的数量,transpose(1, 2) 调整维度顺序以便与 Transformer 等模型兼容
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        # 计算卷积操作的 FLOPs。每个补丁的计算量为 embed_dim * in_chans * (patch_size[0] * patch_size[1]),然后乘以补丁的数量 Ho * Wo。
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # 如果有归一化层,计算其 FLOPs。每个补丁需要 embed_dim 次操作,然后乘以补丁的数量
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

 5.PatchMerging : 图像合并

# 图像块(patches)合并
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution # 输入特征图的分辨率(高和宽)
        self.dim = dim # 输入特征图的通道数。
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) # 线性变换层,用于减少通道数(从 4 * dim 到 2 * dim)
        self.norm = norm_layer(4 * dim) # 归一化层,用于规范化特征图。

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape    # H*W 是特征图的总像素数
        # 验证输入特征图的尺寸是否符合预期
        # 确保输入特征图的展平形式的总元素数量与原始的高(H)和宽(W)的乘积一致。
        assert L == H * W, "input feature has wrong size" # 如果不等,则说明输入特征图的尺寸与预期不符,会抛出错误信息 "input feature has wrong size"。
        # 确保输入特征图的高和宽都是偶数。
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # 从每个 H x W 的特征图中提取四个不同的子块。每个子块的大小是 [B, H/2, W/2, C]
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        # 将四个子块沿通道维度拼接,形成新的特征图
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x) # 通过线性层(self.reduction)减少通道数,从 4*C 到 2*C

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim # 假设有 H * W 个位置,每个位置涉及 self.dim 次浮点运算。这个值通常用于计算在图像上进行某些操作(如卷积)时的 FLOPs
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
        """
           (H // 2) * (W // 2):表示下采样后的图像分辨率(假设下采样因子为2),即每个维度缩小一半后的新高度和宽度。
            4:可能表示有4个操作,或者某种操作的系数。
            self.dim * 2 * self.dim:表示每个位置涉及的浮点运算量。
            这里的 2 * self.dim 可能表示某种计算,例如两个矩阵乘法操作(self.dim 为矩阵的维度),每个操作有 self.dim 次运算
        """

6.BasicLayer : 处理输入数据的处理和转换

"""
BasicLayer 类用于构建 Transformer 模型中的基本层,特别是在类似 Swin Transformer 的架构中。
它定义了多个 Transformer 块的堆叠,并可选择性地包括下采样层。它处理输入数据的处理和转换,
支持设置各种超参数,如注意力头数、窗口大小和 dropout 比例,以实现有效的特征提取和变换。
"""
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int):  输入特征的维度->通道数                                     Number of input channels.
        input_resolution (tuple[int]): 输入图像的分辨率                       Input resolution.
        depth (int):  Transformer 层的深度,即包含多少个 SwinTransformerBlock   Number of blocks.
        num_heads (int): 自注意力机制中的头数                                    Number of attention heads.
        window_size (int): 窗口的大小,用于局部自注意力                             Local window size.
        mlp_ratio (float): MLP 部分的比率,控制前馈网络的隐藏层维度                  Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): 是否在自注意力的 Q、K、V 矩阵中使用偏置            If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional):  缩放因子,用于自注意力中的 Q 和 K。        Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout 概率,用于前馈网络                          Dropout rate. Default: 0.0
        attn_drop (float, optional): Dropout 概率,用于自注意力机制                 Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional):  Drop path 概率,用于模型训练时的正则化  Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default:   归一化层的类型        nn.LayerNorm
        downsample (nn.Module | None, optional): 一个可选的下采样模块,用于改变特征图的分辨率           Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. 是否使用检查点技术以节省内存        Default: False.
        fused_window_process (bool, optional): 是否使用融合的窗口处理方法,以优化计算效率             If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """
    """
    blocks: 由多个 SwinTransformerBlock 组成的 nn.ModuleList,每个块的 shift_size 取决于其在深度中的位置。
    downsample: 如果 downsample 不为空,则构造一个下采样模块,用于调整特征图的分辨率。
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 fused_window_process=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks 创建了一个包含多个 SwinTransformerBlock 的 nn.ModuleList
        # 通过列表推导式生成每一层 SwinTransformerBlock,并根据其在网络中的位置调整 shift_size 和 drop_path。
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 fused_window_process=fused_window_process)
            for i in range(depth)])

        # patch merging layer
        # 如果提供了 downsample 函数,则创建下采样层 self.downsample。否则,将 self.downsample 设置为 None。
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        # 对输入 x 应用所有块。如果 use_checkpoint 为 True,则使用检查点来节省内存。否则,直接应用块
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # 如果存在下采样层,则对结果应用下采样。返回最终的输出 x。
        if self.downsample is not None:
            x = self.downsample(x)
        return x
    # 返回类的额外信息,帮助了解模型的维度、输入分辨率和深度。
    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
    # FLOPs 计算方法
    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

7.WindowAttention

8.SwinTransformerBlock

9.SwinTransformer

  • 13
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值