Swin Transformer详解

本文深入解析Swin Transformer架构,包括基于窗口/移动窗口的自注意力机制,以及在图像分类和目标检测任务中的实现。Swin Transformer通过限制注意力计算在窗口内降低复杂度,同时通过移动窗口实现不同区域的信息交流,有效提升模型性能。
摘要由CSDN通过智能技术生成

原创:余晓龙

“Swin Transformer: Hierarchical Vision Transformer using Shifted Window”是微软亚洲研究院(MSRA)发表在arXiv上的论文,文中提出了一种新型的Transformer架构,也就是Swin Transformer。本文旨在对Swin Transformer架构进行详细解析。

一、Swin Transformer网络架构

整体的网络架构采取层次化的设计,共包含4个stage,每个stage都会缩小输入特征图的分辨率,类似于CNN操作逐层增加感受野。对于一张输入图像224x224x3,首先会像VIT一样,把图片打成patch,这里Swin transformer中使用的patch size大小为 4x4,不同于VIT中使用的大小为16x16。经过Patch Partition,图像的大小会变成56 x 56 x 48, 其中48为 (4x4x3)3 为图片的rgb通道。打完patch之后会经过Linear Embedding,这里的主要目的是为了把向量的维度变成我们预先设定好的值,即可以满足transformer可以输入的值。在Swin-T网络中,这里C的大小为96,得到的网络输出值为56x56x96。之后经过拉直,序列长度变成3136 x 96。其代码如下:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


if __name__ == '__main__':
    x = torch.randn(8, 3, 224, 224)
    x, W, H = PatchEmbed()(x)
    print(x.size())  # torch.Size([8, 3136, 96])
    print(W)  # 56
    print(H)  # 56

Swin transformer引入了基于窗口的自注意力计算,每个窗口为 7x7=49个patch。如果想要有多尺度的特征信息,就需要构建一个层级式的transformer,类似卷积神经网络里的池化操作,Patch Merging用于缩小分辨率,调整通道数,完成层级式的设计。这里每次的降采样为2,在行和列方向每隔一个点选取元素,之后拼接在一起展开。

相当于在空间上的维度去换到了更多的通道数,维度变成4C,之后在C的维度上利用全连接层,将通道数的大小变成2C,经过上述操作之后网络输出的大小变为28 x 28 x 192。之后经过拉直,序列长度变成784 x 192。后面的stage3、stage4同理。最终的网络输出的大小变为7 x 7 x 768。之后经过拉直,序列长度变成49 x 768。代码如下:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x


if __name__ == '__main__':                            
    x = torch.randn(8, 3, 224, 224)                   
    x, H, W = PatchEmbed()(x)                         
    # print(x.size())  # torch.Size([8, 3136, 96])    
    # print(W)  # 56                                  
    # print(H)  # 56                                  
                                                      
    x = PatchMerging(dim=96)(x, H, W)                 
    print(x.size())       # torch.Size([8, 784, 192]) 

基于窗口/移动窗口的自注意力

由于全局的自注意力计算会导致平方倍的复杂度,因此作者提出了基于窗口的自注意力机制。原来的图片会被平均分成一些没有重叠的窗口,以第一层的输入为例,尺寸大小为56 x 56 x 96。

在每一个小方格中会有7x7=49个patch,因此大的特征图可以分为 56 / 7 x 56 / 7 = 8 x 8 个窗口。

基于窗口的自注意力机制与基于全局的自注意力机制复杂度对比:

以标准的多头自注意力为例, 对于一个输入,自注意力首先会将它变成q, k, v三个向量,之后得到的q, k 相乘得到attention,在有了自注意力之后后和得到的v进行相乘,相当于做了一次加权,最后因为这是使用了多头自注意力机制,还会经过一个projection layer,这个投射层就会把向量的维度投射到我们想要的那个维度,如下图:

公式一 :
3 h w c 2 + ! ( h w ) 2 c + ( h w ) 2 c + h w c 2 3hwc^{2} + ! (hw)^{2}c + (hw)^{2}c + hwc^{2} 3hwc2+!(hw)2c+(hw)2c+hwc2

公式二:基于窗口的自注意力复杂度 一个窗口大小 M x M 代入公式一得

4 M 2 c 2 + 2 M 4 c 4M^{2}c^{2} + 2M^{4}c 4M2c2+2M

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值