SwinTransformer细节及代码实现(pytorch版本)

paper: https://arxiv.org/abs/2103.14030
原版code: https://github.com/microsoft/Swin-Transformer

作者分析表明,Transformer从NLP迁移到CV上没有大放异彩主要有两点原因:

  1. 两个领域涉及的scale不同,NLP的scale是标准固定的,而CV的scale变化范围非常大。
  2. CV比起NLP需要更大的分辨率,而且CV中使用Transformer的计算复杂度是图像尺度的平方,这会导致计算量过于庞大。为了解决这两个问题,Swin Transformer相比之前的ViT做了两个改进:1.引入CNN中常用的层次化构建方式构建层次化Transformer 2.引入locality思想,对无重合的window区域内进行self-attention计算。
    3.

一、文章简介

如图1(a)所示,Swin Transformer通过从小尺寸的patches(以灰色框)开始,并逐渐将相邻patches合并到更深的Transformer层中来构建层次表示。有了这些分层特征映射,Swin Transformer模型可以方便地利用高级技术进行密集预测,如特征金字塔网络(FPN)或U-Net。线性计算复杂性是通过在分割图像的非重叠窗口(用红色标出)内局部计算自我注意来实现的。每个窗口中的patches数是固定的,因此复杂性与图像大小成线性关系。这些优点使Swin Transformer适合作为各种视觉任务的通用主干,这与以前基于Transformer的体系结构]不同,后者生成单一分辨率的特征图,并且具有二次复杂性。
Swin Transformer的一个关键设计元素是在连续的自注意力层之间切换窗口分区,如下图所示。移动的窗口桥接了前一层的窗口,提供了它们之间的连接,显著增强了建模能力。这种策略对于真实世界的延迟也是有效的:一个窗口中的所有query patches都共享相同的 key set,这有助于硬件中的内存访问。相比之下,早期基于滑动窗口的自注意力在通用硬件上的延迟较低,因为不同query像素的key sets不同。移位窗口方法的延迟比滑动窗口方法低得多,但建模能力相似。事实证明,移位窗口方法也适用于所有MLP体系结构。
在这里插入图片描述
在所提出的Swin Transformer架构中,上图一个用于计算自注意力的移位窗口方法的示例。在l层(左),采用规则的窗口划分方案,并在每个窗口内计算自注意力。在下一层l+1(右)中,窗口分区被移动,从而产生新的窗口。新窗口中的自我注意计算跨越了层l中以前窗口的边界,提供了它们之间的连接。

二、整体实现

在这里插入图片描述
Transformer体系结构的概述,它展示了微型版本(SwinT)。它首先通过patch splitting module(如ViT)将输入的RGB图像分割为非重叠 patch。每个 patch被视为一个“token”,其特征被设置为原始像素RGB值之间的串联。在我们的实现中,我们使用4×4的patch大小,因此每个patch的特征维数为4×4×3=48。在该原始值特征上应用线性嵌入层,将其投影到任意维度(表示为C),patch块的数量为H/4 x W/4。

class PatchEmbed(nn.Module):
    r""" patch splitting 

    Args:
        img_size (int): 输入图片的尺寸
        patch_size (int): Patch token的尺度. Default: 4.
        in_chans (int): 输入通道的数量
        embed_dim (int): 线性投影后输出的维度
        norm_layer (nn.Module, optional): 这可以进行设置,在本层中设置为了None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=512, embed_dim=96, norm_layer=None):
        super().__init__()
        # 先将img_size、patch_size转化为元组模式(224 , 224) 、 (4 , 4)
        img_size = to_2tuple(img_size) 
        patch_size = to_2tuple(patch_size)
        # 计算出 Patch token在长宽方向上的数量
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        # 计算出patch的数量利用Patch token在长宽方向上的数量相乘的结果
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # 判断是否使用norm_layer,在这里我们没有应用
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
    	# 解析输入的维度
        B, C, H, W = x.shape
        # 判断图像是否与设定图像一致,如果不一致会报错
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # 经过一个卷积层来进行一个线性变换,并且在第二个维度上进行一个压平操作,维度为(B, C, Ph*Pw),后在进行一个维与二维的一个转置,维度为:(B Ph*Pw C)
        x = self.proj(x).flatten(2).transpose(1, 2)  
        if self.norm is not None:
            x = self.norm(x)
        return x

在这些patch tokens上应用了几个Swin Transformer blocks的bolck。Transformer blocks 与patch tokens数量(h/4×w/4)一致,与线性嵌入一起被称为“阶段1”。

为了产生分层表示,随着网络的深入,通过patch合并层来减少patch tokens的数量。第一个patch合并层连接每个2×2相邻patch的特征(如下图红色框内的patch将合并成一组),并在4C维连接的特征上应用线性层。这将patch tokens的数量减少了2×2=4的倍数(分辨率的2×降采样),并且输出维度设置为2C。

lass PatchMerging(nn.Module):
    r""" Patch合并.

    Args:
        input_resolution (tuple[int]): 输入特征的分辨率.
        dim (int): 输入通道的数量
        norm_layer (nn.Module, optional): 定义Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # 通过一个线性层将4C降为2C
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        # 解析输入图像的分辨率,即输入图像的长宽
        H, W = self.input_resolution
        # 解析输入图像的维度
        B, L, C = x.shape
        # 判断L是否与H * W一致,如不一致会报错
        assert L == H * W, "input feature has wrong size"
        # 判断输入图像的长宽是否可以被二整除,因为我们是通过2倍来进行下采样的
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
		# 将xreshape为维度:(B, H, W, C)
        x = x.view(B, H, W, C)
		# 切片操作,通过切片操作将将相邻的2*2的patch进行拼接
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        # 将合并好的patch通过c维进行拼接
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        # 将x的维度重置为B H/2*W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
		
        x = self.norm(x)
        # 通过一个线性层进通道降维
        x = self.reduction(x)

        return x

然后应用Swin Transformer blocks进行特征变换,分辨率保持在h/8×w/8。第一个patch merging和特征转换块被称为“阶段2”。该过程重复两次,分别为“阶段3”和“阶段4”,输出分辨率分别为H/16×W/16和H/32×W/32。这些阶段共同产生了一种阶段的表示。一种具有与典型卷积网络相同的特征图分辨率,例如VGG和ResNe。因此,所提出的体系结构可以方便地替代现有方法中用于各种视觉任务的主干网络。
下面先实现一下下图中的代码:
在这里插入图片描述

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): 输入维度的数量
        input_resolution (tuple[int]): 输入图像的分辨率
        num_heads (int): 应用注意力头的数量
        window_size (int): 窗口的尺寸
        shift_size (int): SW-MSA的循环位移的大小,默认为零
        mlp_ratio (float): mlp的隐层的比例,默认为零
        qkv_bias (bool, optional): 是否应用位移偏量,具体实现在下文
        qk_scale (float | None, optional
  • 21
    点赞
  • 203
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值