The Audio/Video Development Journey (91) - Swin Transformer Paper Walkthrough and Source Code Analysis

Table of Contents

1. Background and Problems

2. Swin Transformer Model Architecture

3. Patch Merging

4. Window Attention

5. Shifted Window Attention

6. Experimental Results

7. Source Code Analysis

8. References

1. Background and Problems

In the previous article we studied the principles of Vision Transformer (ViT): the image is split into fixed-size patches, each patch is linearly projected into a patch embedding, and the patch embeddings are fed into a Transformer encoder for feature extraction. Experiments also showed that on large datasets ViT outperforms CNNs. However, several problems remain:

1. ViT is pre-trained at a default resolution of 224x224. For larger images, the computational complexity of ViT's self-attention grows quadratically with resolution.

2. ViT uses patches of a fixed size, which makes it hard to capture image features at different scales.

3. To handle images of a different resolution, ViT needs separate pre-training or interpolation of the position embeddings, because the position embedding is tied to the number of patches: each patch has a fixed resolution, so as the image resolution grows, the number of patches grows with it.

4. ViT mainly attends to global information and may miss local details.

We can borrow the hierarchical structure of CNNs and use a feature-pyramid-style design that fuses features at different resolutions, so the model attends to both global and local information. But how do we solve the computational complexity problem? That is exactly what Swin Transformer sets out to address.

2. Swin Transformer Model Architecture

Building on ViT, Swin Transformer introduces Window Attention, which restricts the self-attention computation to local windows and reduces the complexity from O(n²) to O(n) in the number of tokens. The Shifted Window Attention mechanism then lets the model exchange information across windows. It also borrows the hierarchical structure of traditional CNNs, downsampling the feature map stage by stage so the model can attend to both global and local information. Swin Transformer has become a general-purpose backbone in computer vision.

[Figure]

Compared with ViT, Swin Transformer computes self-attention within local windows and builds hierarchical feature maps, which makes it much more efficient on large images.

[Figure]

The model combines hierarchical feature maps with local-window attention. It contains 4 stages, each built from a Patch Merging layer and Swin Transformer blocks. The Patch Merging module reduces the resolution of the input feature map and increases the channel count, enlarging the receptive field layer by layer just like a CNN.

The Swin Transformer block is the core of the model. Blocks come in pairs: a block using W-MSA (window multi-head self-attention over non-shifted local windows) followed by a block using SW-MSA (shifted-window multi-head self-attention). W-MSA reduces the computational complexity through local-window attention, while SW-MSA adds information exchange between windows through the shifted-window mechanism.

This architectural design adapts to different types of vision tasks, e.g. image classification, object detection, and semantic segmentation.

The paper gives four configurations of different depths, shown in the table below. For example, Swin-T stacks (2, 2, 6, 2) Swin Transformer blocks across its 4 stages (a sketch of these configurations follows the table).

[Table: configurations of the Swin-T/S/B/L variants]
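The parameter values below follow the paper's configuration table; the dictionary itself (the name SWIN_CONFIGS and the key names) is just an illustrative sketch that maps those values onto the constructor arguments used in the source code section later on.

# Stage depths and attention heads for the four Swin variants (from the paper);
# the embedding dimension doubles at every stage via Patch Merging.
SWIN_CONFIGS = {
    "swin_tiny":  dict(embed_dim=96,  depths=(2, 2, 6, 2),  num_heads=(3, 6, 12, 24)),
    "swin_small": dict(embed_dim=96,  depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)),
    "swin_base":  dict(embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)),
    "swin_large": dict(embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)),
}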

3. Patch Merging

Patch Embedding serves the same purpose as in ViT: the image is split into n patches and features are extracted from each patch as its embedding. The implementation differs slightly: ViT uses a linear projection, whereas Swin Transformer uses a convolution (see the code analysis later).

Patch Merging

Downsampling is performed at the start of each stage to reduce the resolution and adjust the channel count. This produces the hierarchical design and also saves some computation.

[Figure]

Image source: 图解Swin Transformer

The Swin Transformer block is the core of the whole design and covers many of the paper's key ideas: relative position encoding, masking, window self-attention, and shifted-window self-attention. We study these in detail below.

4. Window Attention

ViT computes attention globally, so as the resolution grows the computational cost grows quadratically. Swin Transformer instead restricts the attention computation to each window, which reduces the amount of computation. This is Window Attention.

Suppose an image is split into h x w patches and each window contains M x M patches.

Vanilla self-attention has a complexity quadratic in hw (the number of patches), whereas Swin Transformer's window self-attention computes attention within each window, giving a complexity proportional to M²·hw. For a fixed window size M, window self-attention therefore reduces the complexity from quadratic to linear in the number of patches.
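To make the scaling concrete, here is a small illustrative calculation (my own, not from the paper) that counts only the dominant attention term, i.e. the number of token pairs attended to:

# Illustrative only: dominant attention term for global vs. window attention.
h = w = 56          # a 224x224 image with 4x4 patches gives 56x56 patches
M = 7               # window size used by Swin Transformer

tokens = h * w                      # 3136 tokens
global_pairs = tokens ** 2          # ~9.8M token pairs for global self-attention
window_pairs = (M * M) * tokens     # ~154k token pairs for window self-attention
print(global_pairs, window_pairs, global_pairs // window_pairs)  # ratio = (h*w)/(M*M) = 64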

Attention(Q, K, V) = SoftMax(Q·K^T / sqrt(d) + B)·V

Compared with standard attention, a relative position bias B is added. (How is this relative position bias computed? The answer comes later.)

5. Shifted Window Attention

To allow information exchange between different windows, Swin Transformer introduces Shifted Window Attention.

As shown in the figure below, the left side is Window Attention with the regular, non-overlapping window partition (no interaction across windows), and the right side is Shifted Window Attention, where the window partition is shifted. The shifted windows now contain elements from what were previously adjacent windows, enabling information exchange across windows. But this introduces a new problem: the number of windows grows from the original 4 to 9.

[Figure]

To avoid this increase in the number of windows, the authors cleverly shift the feature map itself: the content is cyclically rolled along the horizontal and vertical dimensions (the top rows / left columns roll around to the bottom rows / right columns), bringing the window count back down to 4. However, when computing local attention inside these new windows, content that originally belonged to different windows should not take part in each other's QKV computation, so the authors introduce a mask. It is an elegant trick, and the implementation is equally clever (a small demo follows the figures below). The relative position bias itself is computed from relative_position_bias_table (the relative position bias table) and relative_position_index (the relative position index).

[Figure]

[Figure]

Image source: 图解Swin Transformer
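Here is a small self-contained toy example (my own, inlining a minimal copy of the window_partition helper) that mirrors the mask-building logic of the official SwinTransformerBlock shown in section 7.4, for an 8x8 feature map with window_size=4 and shift_size=2:

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

H = W = 8
window_size, shift_size = 4, 2

# label the regions of the (conceptually rolled) feature map, as the official code does
img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)  # torch.Size([4, 16, 16]): one 16x16 mask per window
# entries are 0 where two patches came from the same original region and -100 otherwise,
# so after Softmax the cross-region attention weights are effectively zero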

The overall computation of two consecutive blocks is as follows:

ẑ^l     = W-MSA(LN(z^(l-1))) + z^(l-1)        (1)
z^l     = MLP(LN(ẑ^l)) + ẑ^l                  (2)
ẑ^(l+1) = SW-MSA(LN(z^l)) + z^l               (3)
z^(l+1) = MLP(LN(ẑ^(l+1))) + ẑ^(l+1)          (4)

Equation 1: compute the intermediate output ẑ^l of layer l. Apply LayerNorm to the previous layer's output z^(l-1), then W-MSA (window attention), and finally add the residual connection z^(l-1), which helps avoid vanishing gradients.

Equation 2: compute the output z^l of layer l. Apply LayerNorm to the intermediate output ẑ^l from Equation 1, then the MLP, and finally add the residual connection ẑ^l.

Equation 3: compute the intermediate output ẑ^(l+1) of layer l+1. Apply LayerNorm to z^l from Equation 2, then SW-MSA (shifted-window attention), and finally add the residual connection z^l.

Equation 4: compute the output z^(l+1) of layer l+1. Apply LayerNorm to the output of Equation 3, then the MLP, and add the residual connection.

Relative Position Index and Relative Position Bias

The relative position bias is the B in the Window Attention formula below.

Attention(Q, K, V) = SoftMax(Q·K^T / sqrt(d) + B)·V

The paper itself does not explain the relative position index in detail, and the code alone is quite confusing. The Bilibili video Swin-Transformer网络结构详解 (see reference 4) explains it very clearly; combined with the source code, everything falls into place.

# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
# add an offset so the relative coordinates start from 0, which simplifies later computation
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
# the 2-D relative coordinates are later flattened into a 1-D offset. (2,1) and (1,2) are
# different positions in 2-D, but naively summing x and y would give the same offset, so the
# row coordinate is multiplied first to keep them distinguishable
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
# sum over the last dimension to get the relative position index, which indexes the relative
# position bias table so each token pair gets its own bias value in self-attention
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
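To see what this produces, here is a stand-alone reproduction (my own test snippet) for a 2x2 window; the index of each token relative to itself lands in the middle of the bias table:

import torch

window_size = (2, 2)  # Wh, Ww
coords = torch.stack(torch.meshgrid([torch.arange(window_size[0]), torch.arange(window_size[1])]))
coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
# values range over 0..8 = (2*Wh-1)*(2*Ww-1) - 1, matching the size of the bias table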

[Figure]

Shift the offsets so they start from 0: add M-1 to both the row and the column coordinates.

[Figure]

Next, multiply all row coordinates by 2M-1.

[Figure]

Finally, add the row and column coordinates together. This preserves the relative position relationship while avoiding collisions such as 0 + (-1) = (-1) + 0. Very neat.

[Figure]

Using the relative position index computed above, the relative position bias B is looked up from the relative position bias table (relative_position_bias_table).

[Figure]

Thanks to 霹雳吧啦Wz for the thorough explanation.

6. Experimental Results

[Table: accuracy and compute comparison with ViT]

The table above shows that, with or without pre-training, Swin Transformer achieves higher accuracy than ViT while requiring less computation.

7. Source Code Analysis

Source: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

7.1 PatchEmbed

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding.
    Splits the input image into patches and extracts features from each patch as its embedding.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # feature extraction with a convolution used as the embedding:
        # with the defaults, kernel_size 4x4, stride 4, 3 input channels, 96 output channels
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        """
        x.shape = (B, 3, H, W)
        self.proj(x).shape = (B, 96, H//4, W//4)
        self.proj(x).flatten(2).shape = (B, 96, H//4 * W//4)
        self.proj(x).flatten(2).transpose(1, 2).shape = (B, H//4 * W//4, 96)
        """
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

7.2 PatchMerging

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Patch Merging is part of each Swin Transformer stage;
    it halves the resolution and doubles the number of channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature (in patches).
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # split the feature map into four parts, each of shape (B, H/2, W/2, C)
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        # concatenate along the last (channel) dimension: (B, H/2, W/2, 4*C)
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        # flatten the spatial dimensions back into one
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        # reduce the channels from 4*C to 2*C
        x = self.reduction(x)
        x = self.norm(x)

        return x
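A quick shape check: the resolution is halved in each direction and the channel count is doubled.

merge = PatchMerging(input_resolution=(56, 56), dim=96)
x = torch.randn(1, 56 * 56, 96)
print(merge(x).shape)  # torch.Size([1, 784, 192]): 56x56 -> 28x28, 96 -> 192 channels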

7.3 WindowAttention

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        """
        torch.meshgrid builds a coordinate grid, torch.stack stacks the two grids along a new
        dimension, and torch.flatten flattens them into 1-D tensors.

        e.g. for a 3x3 window, torch.stack(torch.meshgrid([coords_h, coords_w])) gives
            tensor([[[0, 0, 0],
                     [1, 1, 1],
                     [2, 2, 2]],

                    [[0, 1, 2],
                     [0, 1, 2],
                     [0, 1, 2]]])

        and torch.flatten(coords, 1) gives:
            tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
                    [0, 1, 2, 0, 1, 2, 0, 1, 2]])
        """
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        # add an offset so the relative coordinates start from 0
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # the 2-D relative coordinates are later flattened into a 1-D offset. (2,1) and (1,2) are
        # different in 2-D, but naively summing x and y gives the same offset, so the row
        # coordinate is multiplied first to keep them distinct
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # sum over the last dimension to get the relative position index, which indexes the relative
        # position bias table so each token pair gets its own bias value in self-attention
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        '''
        x.shape = (num_windows*B, N, C)
        self.qkv(x).shape = (num_windows*B, N, 3C)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).shape
            = (num_windows*B, N, 3, num_heads, C//num_heads)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).shape
            = (3, num_windows*B, num_heads, N, C//num_heads)
        '''
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        # To connect non-overlapping windows, the shifted-window scheme re-partitions the image with
        # windows shifted by (M/2, M/2), so previously unconnected patches fall into the same window.
        # That would raise the window count from 4 to 9, so the paper uses a cyclic shift to avoid the
        # extra computation while keeping the cross-window connections.
        # To keep the shifted-window attention correct, self-attention may only be computed between
        # tokens of the same sub-window; attention between different sub-windows is suppressed by
        # adding -100 to those logits via the mask, since Softmax(-100) is effectively 0.
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
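A quick shape check for a batch of windows (assuming the file's imports, e.g. trunc_normal_ from timm, are in place): one image partitioned into 7x7 windows from a 56x56 feature map gives 64 windows of 49 tokens each.

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
x_windows = torch.randn(64, 49, 96)   # nW*B, window_size*window_size, C
print(attn(x_windows).shape)          # torch.Size([64, 49, 96])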

7.4 SwinTransformerBlock

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
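The block relies on two helpers, window_partition and window_reverse, which live in the same source file but are not shown above. For completeness, this is the logic they implement (reproduced from memory, so double-check it against the repo):

def window_partition(x, window_size):
    """(B, H, W, C) -> (num_windows*B, window_size, window_size, C)"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """(num_windows*B, window_size, window_size, C) -> (B, H, W, C)"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x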

7.5 BasicLayer (i.e. one stage, containing the SwinTransformerBlock and PatchMerging modules)

class BasicLayer(nn.Module):
    """
    A BasicLayer is one stage, containing SwinTransformerBlock and PatchMerging modules.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 fused_window_process=fused_window_process)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
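A quick shape check for what stage 1 of Swin-T would look like (hyperparameters assumed from the configuration table, and the rest of the official file available): two blocks at 56x56 with dim 96, alternating W-MSA and SW-MSA, followed by PatchMerging down to 28x28 with dim 192.

stage1 = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3,
                    window_size=7, downsample=PatchMerging)
x = torch.randn(1, 56 * 56, 96)
print(stage1(x).shape)  # torch.Size([1, 784, 192])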

7.6 SwinTransformer

class SwinTransformer(nn.Module):
    r""" The Swin Transformer network.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, fused_window_process=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               fused_window_process=fused_window_process)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
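Finally, an end-to-end sanity check, assuming the full official swin_transformer.py is importable (it also needs timm for to_2tuple, DropPath, and trunc_normal_):

# Swin-T with the default arguments: 4 stages of depths (2, 2, 6, 2)
model = SwinTransformer(img_size=224, patch_size=4, embed_dim=96,
                        depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        window_size=7, num_classes=1000)
x = torch.randn(1, 3, 224, 224)
logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])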

Swin Transformer still has some shortcomings:

  1. The model scale is not large enough.

  2. The image resolution and window size used in pre-training do not match those of downstream tasks.

The Swin Transformer V2 paper proposes solutions to these problems and is worth a closer look if you are interested.

Swin Transformer V2 scales the model up and adapts it to images of different resolutions and windows of different sizes.

8. References

1. Swin Transformer paper: https://arxiv.org/pdf/2103.14030

2. Swin Transformer V2 paper: https://arxiv.org/pdf/2111.09883

   Official implementation: https://github.com/microsoft/Swin-Transformer/tree/main/models

3. 图解Swin Transformer: https://zhuanlan.zhihu.com/p/367111046

4. Swin-Transformer网络结构详解 (video): https://www.bilibili.com/video/BV1pL4y1v7jC

5. Swin-Transformer网络结构详解 (article): https://blog.csdn.net/qq_37541097/article/details/121119988

6. https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/swin_transformer

7. 论文详解:Swin Transformer: https://zhuanlan.zhihu.com/p/430047908

8. CV+Transformer之Swin Transformer: https://zhuanlan.zhihu.com/p/361366090

9. AI大模型系列之三:Swin Transformer 最强CV图解: https://blog.csdn.net/Peter_Changyb/article/details/137183056

10. Swin Transformer 论文详解及程序解读: https://zhuanlan.zhihu.com/p/401661320

Thanks for reading.

We will keep learning and writing about AI-related topics. Follow the WeChat official account "音视频开发之旅" to learn and grow together.

Comments and discussion are welcome.
