【笔记】Swin-Transformer 的计算量与Transformer的计算量的对比:前者通过使用新颖的窗口技巧,将后者的高阶项变为低阶,大大降低了计算量

补充1:

  • 局部窗口内的自注意力(W-MSA):

    • 在 Swin Transformer 中,输入特征图被划分为多个小的窗口(例如 7x7 的窗口)。在每个窗口内,计算自注意力机制(W-MSA, Window-based Multi-Head Self-Attention),这意味着每个 token 只和同一窗口内的其他 token 进行交互。
    • 由于计算只发生在局部窗口内,所以计算复杂度大大降低,特别是对于高分辨率的输入图像来说,这种方式更加高效。
  • 滑动窗口机制(Shifted Window Attention):

    • 为了在局部窗口之间传递信息,Swin Transformer 引入了滑动窗口机制。通过在不同的层中移动窗口的位置,使得相邻窗口之间的特征可以进行交流,从而保证全局上下文的信息逐步整合到特征中。
  • 计算量的比较:

    • 全局自注意力(MSA):像 Vision Transformer (ViT) 这样的方法在整个特征图上计算自注意力,计算复杂度是 O((hw)²)。
    • 窗口内自注意力(W-MSA):Swin Transformer 仅在每个窗口内计算,计算复杂度降低为 O(W * M²),其中 W 是窗口的数量,M 是窗口内的 token 数量(例如 7x7 = 49 个 token)。
  • 滑动窗口的好处:

    • 滑动窗口机制允许信息在不同窗口之间传递,而不仅仅局限在窗口内部。这种设计平衡了计算效率和特征提取的全局性,确保 Swin Transformer 可以在较低的计算复杂度下仍然获得良好的表现。

补充2:

关于特征图大小的解释:

  • 输入图像大小(224x224):

    • 在大多数计算机视觉任务中,输入图像通常会被调整为 224x224 像素。
  • Patch Embedding 和 Stride=16:

    • Vision Transformer (ViT) 通常会将输入图像划分为 16x16 的 non-overlapping patches,然后将每个 patch 展平并映射到一个高维的特征空间。
    • 因为每个 patch 的大小是 16x16,且是 non-overlapping 的,这相当于对输入图像应用了一个 Stride=16 的卷积操作,将图像的空间分辨率从 224x224 减少到 14x14。
  • 特征图大小(14x14):

    • 因此,经过 Stride=16 的操作后,原始图像被划分为 14x14 个 patch,每个 patch 被视为一个 token。在 Vision Transformer 中,这 14x14 个 token 会形成一个 196 维的 token 序列。

Swin Transformer 的不同点:

Swin Transformer 在一些细节上和 ViT 有所不同:

  • 多级特征图:
    • Swin Transformer 处理的是逐级降低空间分辨率的特征图(类似于卷积神经网络中的多尺度特征),比如从最开始的较大特征图(例如 h=w=56)到最后的较小特征图。
  • 滑动窗口与局部自注意力:
    • 在 Swin Transformer 中,通过窗口内自注意力(W-MSA)和滑动窗口(Shifted Window)机制来逐步处理这些特征图,计算局部区域内的自注意力。

图中的特征图大小与实际应用:

在你提供的图片中,h=w=56 可能指的是在 Swin Transformer 的某个阶段,特征图被处理时的空间分辨率。例如,在较早的阶段,特征图的空间分辨率较高,经过几次降采样后,可能从 224x224 降到 56x56 甚至更低。

因此,特征图的大小 (14x14 或 56x56) 取决于模型的阶段以及具体的网络结构。在 Swin Transformer 中,早期层的特征图可能较大,而后期层的特征图可能较小,这与 Vision Transformer 中固定的 14x14 特征图有所不同。

补充3:

关于正文中h=w=56, m=7 的补充:

在 Swin Transformer 中,h=w=56m=7 是针对特定阶段的特征图大小和窗口大小。这些参数在 Swin Transformer 中是有具体含义的:

1. h=w=56 的解释

  • 初始阶段的特征图大小:
    • Swin Transformer 通常会通过多级特征提取器(类似卷积神经网络中的多尺度特征提取),逐步缩小特征图的空间分辨率。
    • 例如,在初始阶段,输入图像可能会被划分成大小为 4x4 的 patch(相当于 Stride=4),并将输入图像从原始的 224x224 分辨率降采样到 56x56 的特征图。
    • 具体来说,224x224 的输入图像通过 Stride=4 的操作后,特征图的大小变成 224/4 = 56,既 h=w=56

2. m=7 的解释

  • 窗口大小(Window Size):
    • Swin Transformer 的一个关键特点是,它引入了基于窗口的多头自注意力(W-MSA),这个窗口是在特定大小的局部区域内进行自注意力计算的。
    • m=7 表示窗口的大小为 7x7,也就是说,在每一个 7x7 的局部区域内计算自注意力,而不是在整个 56x56 的全局上计算。
    • 通过将大的特征图(例如 56x56)划分为多个 7x7 的窗口,Swin Transformer 可以在保持计算量可控的前提下,捕捉局部的相关性。

Swin Transformer 的多级结构

Swin Transformer 的网络结构通常分为多个阶段,每个阶段的特征图大小和窗口大小可能有所不同:

  • Stage 1: 假设输入图像为 224x224,通过 Stride=4 的 patch embedding 操作,特征图的大小变为 56x56。
  • Stage 2: 在 Stage 1 处理后的 56x56 特征图基础上,应用 7x7 的窗口来进行局部自注意力计算。
  • Stage 3: 特征图继续下采样到更小的分辨率(例如 28x28 或 14x14),然后继续应用更小的窗口进行计算。

注1:

注2:


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
 
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)
 
    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
 
        x = x.view(B, H, W, C)
 
        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
 
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
 
        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]
 
        return x

正文:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序猿的探索之路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值