Swin Transformer: Paper and Code Notes

Paper    Source code

Overall Architecture

Image

The input image is a three-dimensional matrix of size H x W x 3, where H is the height, W is the width, and 3 is the number of channels (RGB here). The figure uses a single image as an example, i.e. batch_size = 1.

Patch Partition

After an image is read in, it is represented as a pixel matrix. The image first goes through patch partition, which changes the smallest unit of the image from a pixel to a patch. In the paper's example a patch consists of 4*4 pixels, i.e. the patch partition module splits the pixel matrix into patches of 4*4 pixels each and concatenates the pixel values within a patch into one vector. After this step the input pixel matrix becomes a three-dimensional matrix of shape H/4 * W/4 * 48, where H/4 * W/4 is the number of patches and 48 is the channel dimension, obtained from 3*4*4.
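
To make the shape bookkeeping concrete, here is a minimal sketch of the partition step using plain tensor reshaping (my own illustration; the official code instead fuses partition and linear embedding into a single strided Conv2d, shown further below):

import torch

# Partition one 224*224 RGB image (batch_size = 1) into 4*4 patches.
x = torch.randn(1, 3, 224, 224)                  # B C H W
B, C, H, W = x.shape
p = 4                                            # 4*4 pixels per patch
x = x.view(B, C, H // p, p, W // p, p)           # B C H/4 4 W/4 4
x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, (H // p) * (W // p), p * p * C)
print(x.shape)                                   # torch.Size([1, 3136, 48]) = B, H/4*W/4, 3*4*4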

Linear Embedding

A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C). C defaults to 96.
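
Conceptually, the linear embedding is just a Linear layer applied to each 48-dimensional patch vector; the actual code below fuses partition and embedding into a single strided Conv2d. A small sketch of the conceptual view:

import torch
import torch.nn as nn

linear_embed = nn.Linear(48, 96)                 # 48 = 3*4*4 raw values per patch -> C = 96
tokens = linear_embed(torch.randn(1, 56 * 56, 48))
print(tokens.shape)                              # torch.Size([1, 3136, 96])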

In the code, the PatchEmbed() class implements both the Patch Partition and Linear Embedding modules. See the comments below for the details.

import torch
import torch.nn as nn
from timm.models.layers import to_2tuple  # turns a single int into an (int, int) tuple


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4 (each patch is 4*4 pixels).
        in_chans (int): Number of input image channels. Default: 3 (RGB).
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        # img_size -> tuple (img_size, img_size), i.e. (224, 224)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        # number of patches in the image
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        #input channel default RGB:3
        self.in_chans = in_chans
        #output channel
        self.embed_dim = embed_dim
        # in_chans: 3, out_chans: 96; the 3 input channels pass through a 4*4 conv kernel with stride 4, giving 96 output channels
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # optionally apply a normalization layer after the embedding
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        # batch channel height width
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # check that the input image size matches what the model expects
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # project each patch to the embedding dim; flatten(2) collapses every dim from index 2 onward into one
        # here x after self.proj is 4-D: B C Ph Pw
        # flatten makes it 3-D (B C Ph*Pw), then transpose swaps dims 1 and 2, giving B Ph*Pw C
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        # normalization
        if self.norm is not None:
            x = self.norm(x)
        return x
    # theoretical computational complexity (FLOPs)
    def flops(self):
        # computed from the number of patches
        Ho, Wo = self.patches_resolution
        # number of patches * (input channels * output channels) * pixels per patch
        flops = Ho * Wo * (self.embed_dim * self.in_chans) * (self.patch_size[0] * self.patch_size[1])
        # if a norm layer is applied, add its cost as well
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
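
A quick usage sketch (my own check, not from the repo) to confirm the output shape:

out = PatchEmbed(img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, norm_layer=nn.LayerNorm)(torch.randn(1, 3, 224, 224))
print(out.shape)                                 # torch.Size([1, 3136, 96]) = B, H/4*W/4, C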

Patch Merging

Implements patch merging: merging 2*2 neighboring patches makes the channel dimension 4 times larger, after which a linear layer reduces it by half, from 4C to 2C.

The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features. This reduces the number of tokens by a multiple of 2×2 = 4 (2× downsampling of resolution), and the output dimension is set to 2C.

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature (the feature map size).
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # input channels: 4*dim, output channels: 2*dim (dimensionality reduction)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        # check that the input matches the expected size
        assert L == H * W, "input feature has wrong size"
        # H and W must both be even so that 2*2 neighboring patches can be grouped
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        #reshape
        x = x.view(B, H, W, C)
        # 0::2 means start at index 0 with a step of 2 and take elements to the end
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        # torch.cat concatenates tensors; dim=-1 is the innermost dim (for dims 0,1,2,3 that is dim 3), so x0..x3 are concatenated along the channel dim C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        # normalization
        x = self.norm(x)
        x = self.reduction(x)
        return x
    # extra information shown in the module's repr
    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
    # theoretical computational complexity (FLOPs)
    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
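
A usage sketch (my own check) with the stage-1 token map, i.e. 56*56 tokens and C = 96:

out = PatchMerging(input_resolution=(56, 56), dim=96)(torch.randn(1, 56 * 56, 96))
print(out.shape)                                 # torch.Size([1, 784, 192]) = B, H/2*W/2, 2*C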

Swin Transformer Block

LN  LayerNorm

A LayerNorm layer is applied before each MSA module and each MLP, and a residual connection is applied after each module.
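
As a minimal sketch of this pre-norm + residual pattern (illustrative names only, not the repo's actual forward method, which also handles the window reshaping and shifting):

def block_forward(x, norm1, attn, norm2, mlp):
    x = x + attn(norm1(x))   # LayerNorm before (S)W-MSA, residual connection after
    x = x + mlp(norm2(x))    # LayerNorm before MLP, residual connection after
    return x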

MLP  Multi-Layer Perceptron

It is simply a two-layer fully connected (fc) network.

class Mlp(nn.Module):
    """
    Muti-Layer Perception ,MLP
    就是一个两层fc网络
    in_features:
    hidden_features:
    out_features:
    act_layer: activation function
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        """
        GELU(x) = x * Φ(x), where Φ is the CDF of the standard normal distribution
        """
        super().__init__()
        # if out_features / hidden_features are not given, they default to in_features (input and output dims match)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # two fully connected layers
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # dropout: randomly zeroes elements to help prevent overfitting
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # dropout
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
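
A usage sketch (my own check); in the Swin block the hidden dimension is mlp_ratio * dim, with mlp_ratio = 4 by default:

out = Mlp(in_features=96, hidden_features=4 * 96, drop=0.)(torch.randn(1, 3136, 96))
print(out.shape)                                 # torch.Size([1, 3136, 96])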

W-MSA and SW-MSA

W-MSA and SW-MSA are multi-head self-attention modules with regular and shifted windowing configurations, respectively.

For multi-head self-attention itself, see other articles. The following mainly covers window_partition, shifted window partition, and window_attention.
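
Since that walkthrough is not reproduced here, below is a sketch of the basic window_partition idea, assuming non-overlapping window_size*window_size windows over a B H W C feature map (following the shape conventions above; the shifted variant additionally rolls the feature map with torch.roll before partitioning):

import torch

def window_partition(x, window_size):
    """Split a B H W C feature map into (num_windows*B, window_size, window_size, C) windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

windows = window_partition(torch.randn(1, 56, 56, 96), window_size=7)
print(windows.shape)                             # torch.Size([64, 7, 7, 96]) -> 8*8 windows of 7*7 tokens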

This part draws on:
