Swin Transformer

Paper link: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

Code link

Swin Transformer serves as a backbone for various downstream tasks, with pretrained models provided; it is positioned as a counterpart to ResNet.

Swin Transformer block: W-MSA + SW-MSA

W-MSA: Window Multi-head Self-Attention, self-attention computed within local windows

SW-MSA: Shifted Window Multi-head Self-Attention, self-attention over windows shifted by half the window size
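To make the two concrete, here is a minimal sketch of how windows are carved out of the patch grid; window_partition below follows the helper of the same name in the official repo, and the half-window torch.roll shift is how the official code realizes SW-MSA:

import torch

def window_partition(x, window_size):
    # x: (B, H, W, C) feature map at patch resolution
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # -> (num_windows*B, window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

x = torch.randn(1, 56, 56, 96)       # stage-1 feature map: 56*56 patches, C = 96
windows = window_partition(x, 7)     # (64, 7, 7, 96): an 8*8 grid of 7*7-patch windows
shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))  # SW-MSA shifts by window_size//2 first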

Hierarchical feature maps

The hierarchical feature maps are built by convolution-free downsampling, called Patch Merging.

Shifted window attention

Three distinct concepts appear in both the paper and the code and must be kept apart: resolution, patches, and windows.

resolution: the input image's resolution is a pixel resolution. The Part 1 code takes the image at pixel resolution, but the H/W in the Part 2 code are the patch resolution, not the pixel resolution.

patches: a 4*4 pixel region of the image is one patch. For the classification task the input pixel resolution is 224*224 and patch_size = 4, so patches_resolution = 56*56.

windows: window size is defined in patches, not pixels. In both the paper and the code window_size = 7, so one window covers 7*7 = 49 patches; the arithmetic is sketched below.
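Putting numbers to the three concepts (a minimal sketch using the classification defaults above):

img_size, patch_size, window_size = 224, 4, 7
patches_per_side = img_size // patch_size             # 56: pixel resolution -> patch resolution
windows_per_side = patches_per_side // window_size    # 8:  patch grid -> window grid
print(patches_per_side ** 2, windows_per_side ** 2)   # 3136 patches, 64 windows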

Part 1: Image Processing

Part 1 consists of two steps. Patch Partition converts the pixel-resolution image into a patch-resolution one; each patch is treated as a token whose feature is the flattened RGB values inside the patch, so token_feature = 4*4*3 = 48. Linear Embedding then projects token_feature to the target dimension (C = 96 for Swin-T). In the code, both steps are implemented together by the PatchEmbed class:

import torch
import torch.nn as nn
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.   Swin_T.C = Swin_S.C = 96  Swin_B.C = 128  Swin_L.C = 192
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    
    Returns:
        (B, 3136, 96): a sequence of 3136 = (224/4)*(224/4) tokens, each a 96-dim feature vector
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()

        img_size = to_2tuple(img_size)  # img_size = (img_size, img_size)
        patch_size = to_2tuple(patch_size) # patch_size = (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size  # self.img_size = (224, 224)
        self.patch_size = patch_size # self.patch_size = (4, 4)
        self.patches_resolution = patches_resolution # self.patches_resolution = [56, 56]
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans  # 3
        self.embed_dim = embed_dim # 96
        #                        3          96               (4,4)                (4,4)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        '''
        The input pixel resolution must be (224, 224), i.e. it must match self.img_size.
        '''
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        '''
        x.shape = (B, 3, H, W)
        self.proj(x).shape = (B, 96, H//4, W//4)
        self.proj(x).flatten(2).shape = (B, 96, H//4 * W//4)
        self.proj(x).flatten(2).transpose(1,2).shape = (B, H//4 * W//4, 96)
        '''
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x  # x.shape = (B, L, C); L = H*W, where H and W are the patch resolution, not the pixel resolution of the input; C = 96

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
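A quick shape check of PatchEmbed (a minimal sketch; it relies on the to_2tuple import from timm that the official implementation uses):

import torch
import torch.nn as nn

embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
x = torch.randn(2, 3, 224, 224)  # (B, C, H, W) at pixel resolution
out = embed(x)
print(out.shape)                 # torch.Size([2, 3136, 96]): 3136 = 56*56 patch tokens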

Part 2: Swin Transformer Stage

A Swin Transformer stage consists of two parts: Patch Merging and the Swin Transformer Block.

1. Patch Merging

Patch Merging halves the resolution and doubles the channel count, analogous to downsampling in a CNN; this is what makes the Transformer hierarchical.

class PatchMerging(nn.Module):
    # PatchMerging: halves the input resolution and doubles the channel count
    # the resolution here is the patch resolution, not the original pixel resolution
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # 0::2 -> even rows/columns
        # 1::2 -> odd rows/columns
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x) # the concat took C to 4C; self.reduction maps 4C back to 2C, so overall C -> 2C: resolution halved, channels doubled

        return x
        # input x.shape = (B, H*W, C), output x.shape = (B, H/2*W/2, 2C): resolution halved, channels doubled
    
    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
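And a matching shape check for PatchMerging (a minimal sketch, feeding it the stage-1 shape produced by PatchEmbed above):

import torch

merge = PatchMerging(input_resolution=(56, 56), dim=96)
x = torch.randn(2, 56 * 56, 96)  # (B, H*W, C) at patch resolution
out = merge(x)
print(out.shape)                 # torch.Size([2, 784, 192]): 28*28 tokens, channels doubled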

References:

Swin-Transformer 详解 (CSDN)

Swin Transformer 论文详解及程序解读 (知乎)

https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
