狗都能看懂的Swin Transformer的讲解和代码实现

1、Swin-Transformer介绍

自从ViT(Vision Transformer)出现之后,这种基于自注意力机制的视觉神经网络逐渐替代CNN称为主流backbone。但由于其参考的self-attention机制是来源于NLP(自然语言)领域的,而语言的文字是人类交流的结晶,本身语义信息高度集中。而CV(计算机视觉)不同,可能从图片上随机扣除一块区域,对整体的识别都没有影响。所以ViT中,很多计算是冗余的,并不需要全局的联系。Swin-Transformer基于ViT的结构进行改进,提出SW/W-MSA结构,有效降低计算量。

原论文地址: https://arxiv.org/abs/2103.14030
官方开源代码:https://github.com/microsoft/Swin-Transformer
Pytorch实现代码:https://github.com/Runist/Swin-Transformer

2、模型整体框架

在讲解之前,如果没有了解过ViTself-attention的读者,建议还是先看一下前面的文章。相比于Vision Transformer来说,有两点不同:

  • Swin Transformer采用了类似CNN的层次化构建方法(Hierarchical feature maps),特征图会随着层数加深,逐渐下采样至4倍,8倍,16倍等。 但Vision Transformer的结构特征图经过一次16倍下采样之后就不变了。
  • Swin Transformer使用了Windows Multi-Head Self-Attention(W-MSA)的网络结构。这个结构限定了self-attention的范围,从全局进行qkv运算,变成只在给定窗口大小的范围内进行qkv计算,有效减少了计算量。但这样做也会隔绝了不同窗口之间的信息传递,所以作者又提出了Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过这个方法,可以让信息在相邻的窗口中传递。

在这里插入图片描述

网络的整体架构如下所示,由多个Swin Transformer Block堆叠而成:

在这里插入图片描述

  • Patch Partition和Linear Embeding就是对应ViT的Embedding层,即将图像分块,映射成self-attention中一个个token。在代码中是这么做的,4x4的像素小块为一个Patch,加上其有3个通道,每个Patch就有16x3=48个像素,这在代码中是由一个4x4的卷积进行处理的。那么图像的shape就从[H, W, 3]变成了[H/4, W/4, 48]。然后再通过Linear Emdbeding层对channel维度的数据做线性变换,由48变成C,即[H/4, W/4, 48]变成了[H/4, W/4, C]。
  • 模型通过Patch Merging进行下采样,每个Stage都会下采样4倍(除了第一个),同时在channel维度上翻倍。每个Stage都是堆叠Swin Transformer Block而成,这里的Block其实有两种结构,如图(b)中所示,这两种结构的不同之处仅在于一个使用了W-MSA结构,一个使用了SW-MSA结构。而且这两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。所以你会发现堆叠Swin Transformer Block的次数都是偶数(因为成对使用)。
  • 对于分类网络在代码中,还有LayerNorm、AvgPooling和一个全连接层组成,这个在图中没有体现。

3、Patch Mergeing详解

每个Stage中会经过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,每隔一个位置取一个像素,从而组成四个feature map。将这四个feature map在channel维度上进行concat拼接,再经过一个LayerNorm层。最后通过一个全连接层在feature map的channel维度上做线性变换,feature map的深度由4*C变成2*C。通过这个例子,可以看出,feature通过Patch Merging层之后,feature map的高和宽会减半,通道数翻倍。

在这里插入图片描述

实现代码:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # *::2,每隔1个取一个值
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]

        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

4、W-MSA模块详解

引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,左边是ViT用的Multi-head Self-Attention(MSA)模块,对于每个Patch都要和除了它自己之外的Patch去计算attention。在Windows Multi-head Self-Attention(W-MSA)模块中,我们会给定一个windows-size(下图windows-size=2),在一个windows内进行Self-Attention的计算。

在这里插入图片描述

这样就有效降低了计算量,具体相差多少呢?论文中给出公式:
Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 ( h w ) 2 C \begin{aligned} & \Omega(MSA) = 4hwC^2 + 2(hw)^2C \\ & \Omega(W-MSA) = 4hwC^2 + 2M^2(hw)^2C \end{aligned} Ω(MSA)=4hwC2+2(hw)2CΩ(WMSA)=4hwC2+2M2(hw)2C

  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的通道数
  • M代表window-size,一般设置为7,是固定的。

这两个公式的推导,原文没有细说,我们简单计算一下。首先看一下Self-Attention的公式:
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d}})V Attention(Q,K,V)=SoftMax(d QKT)V

MSA模块计算量

对于feature map的每个像素(或称为token,patch),都要通过 W q W_q Wq W k W_k Wk W v W_v Wv生成对应qkv。这里假设q,k,v的向量长度与feature map的channel数量C保持一致。那么对应所有像素生成Q的过程如下:
X h w × C ⋅ W q C × C = Q h w × C X^{hw \times C} \cdot W_q^{C \times C} = Q^{hw \times C} Xhw×CWqC×C=Qhw×C

  • X h w × C X^{hw \times C} Xhw×C为所有token拼接一起得到的矩阵(一共有hw个像素,每个像素的深度为C)
  • W q C × C W_q^{C \times C} WqC×C为生成query的变换矩阵
  • Q h w × C Q^{hw \times C} Qhw×C为所有像素通过 W q C × C W_q^{C \times C} WqC×C得到的query拼接后的矩阵

根据矩阵运算的计算量公式可以得到生成 Q Q Q的计算量为 h w × C × C hw \times C \times C hw×C×C,生成K和V的过程一样,同理都是 h w C 2 hwC^2 hwC2,那么总共是 3 h w C 2 3hwC^2 3hwC2。接下来 Q Q Q K T K^T KT相乘,对应计算量为 ( h w ) 2 C (hw)^2C (hw)2C
Q h w × C ⋅ K T ( C × h w ) = X h w × h w Q^{hw \times C} \cdot K^{T(C \times hw)} = X^{hw \times hw} Qhw×CKT(C×hw)=Xhw×hw
这里忽略除以 d \sqrt{d} d 以及softmax的计算量,假设得到 A h w × h w A^{hw \times hw} Ahw×hw,最后还要乘以 V V V,这里对应的计算量是 ( h w ) 2 C (hw)^2C (hw)2C
A h w × h w ⋅ V h w × C ) = X h w × C A^{hw \times hw} \cdot V^{hw \times C)} = X^{hw \times C} Ahw×hwVhw×C)=Xhw×C
那么对应单头的Self-Attention模块,总共需要 3 h w C 2 + ( h w ) 2 C + ( h w ) 2 C = 3 h w C 2 + 2 ( h w ) 2 C 3hwC^2 + (hw)^2C + (hw)^2C = 3hwC^2 + 2(hw)^2C 3hwC2+(hw)2C+(hw)2C=3hwC2+2(hw)2C。而在实际使用过程中,使用的是多头的Multi-head Self-Attention模块,在之前的文章中有进行过实验对比,多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵 W O W_O WO的计算量 h w C 2 hwC^2 hwC2

所以总共加起来是: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C

W-MSA模块计算量

对于W-MSA模块首先要将feature map划分到多个窗口中,假设每个窗口的宽高都是M,那么总共会得到 h M × w M \frac{h}{M} \times \frac{w}{M} Mh×Mw个窗口,在每个窗口内使用多头注意力模块。刚刚计算高度h,宽度为w,通道数为C的feature map计算量为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C,这里 h w hw hw替换为 M M M,代入公式:
4 ( M C ) 2 + 2 ( M ) 4 C 4(MC)^2 + 2(M)^4C 4(MC)2+2(M)4C
又因为又 h M × w M \frac{h}{M} \times \frac{w}{M} Mh×Mw个窗口,则:
h M × w M × ( 4 ( M C ) 2 + 2 ( M ) 4 C ) = 4 h w C 2 + 2 ( M ) 2 h w C \frac{h}{M} \times \frac{w}{M} \times(4(MC)^2 + 2(M)^4C) = 4hwC^2 + 2(M)^2hwC Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2(M)2hwC

所以W-MSA模块的计算量为: 4 h w C 2 + 2 ( M ) 2 h w C 4hwC^2 + 2(M)^2hwC 4hwC2+2(M)2hwC

5、SW-MSA详解

前面又说,采用W-MSA模块时,只会在自己窗口下进行Self-Attention的计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。如下图所示,左边是上面说的W-MSA,右边用的是SW-MSA,两个模块是成对出现的。经过滚动的像素(token),其画窗口之后,框住的token就不同了,这样就使得不同窗口的信息有交流了。

在这里插入图片描述

这个图比较抽象,包括论文中出现的解析图,都画的原理不是很清晰。

在这里插入图片描述

我按照其他博主讲解的思路重新画了一个,我们先按照编号给每个窗口画上标记,左边A对应0区域,B对应3、6区域,C对应1、2区域。

在这里插入图片描述

首先将0、1、2移动至最后一行。

在这里插入图片描述

其次将3、6、0移动至最后一列。

在这里插入图片描述

移动完成之后,4是一个单独区域,5、4为一组,7、1为一组,8、6、2、0为一组。这样都是4x4的窗口,虽然我们这边解析看起来比较麻烦,但在代码中,只需要一个torch.roll()函数就可以实现。但在这里肯定有人回想,5、3本身是两个图像的边缘,混在一起计算不是乱了吗?一起计算也没问题,ViT也是全局计算的。但是Swin-Transformer为了防止这个问题,在代码中使用了masked MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了。源码中具体的方法就是将不计算的位置元素减去100,让权重为0,不让其参与计算,这里就不细说了。

这里需要注意的是,在窗口数据进行滑动完之后,需要将数据还原回去,即挪回到原来的位置上。

对应的代码是:

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        # 把feature map给pad到window size的整数倍
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        
        x = F.pad(x, (0, 0, pad_t, pad_b, pad_l, pad_r))

        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            # paper中,滑动的size是窗口大小的/2(向下取整)
            # torch.roll以H,W的维度为例子,负值往左上移动,正值往右下移动。溢出的值在对角方向出现。即循环移动。
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            # 把前面pad的数据移除掉
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

6、Relative Position Bias详解

关于相对位置偏执,论文里也没有细讲,就说了参考的哪些论文,然后说使用了相对位置偏执后给够带来明显的提升。根据原论文中的表4可以看出,在Imagenet数据集上如果不使用任何位置偏执,top-1为80.1,但使用了相对位置偏执(rel. pos.)后top-1为83.3,提升还是很明显的。
在这里插入图片描述

从论文中提供的公式,这个相对位置的偏执是加载softmax之前的:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T ( d ) + B ) V Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt(d)} + B)V Attention(Q,K,V)=Softmax(( d)QKT+B)V
由于论文没有说明这个相对位置偏执编码如何计算出来的,这里根据源码解释一下。如图,假设我们现在有一个window-size=2的feature map,这里面如果用绝对位置来表示位置索引,左上角的token为(0, 0)右下角的token为(1, 1)其他位置以此类推。然后如果用相对位置表示,就会有4个情况,但分别都是以自己为(0, 0)计算其他token的相对位置。分别把4个相对位置展开,得到4x4的矩阵,如最下的矩阵所示。

在这里插入图片描述

请注意这里说的都是位置索引,并不是最后的位置编码。因为后面我们会根据相对位置索引去取对应位置的参数。取出来的值才是相对位置编码。源码中,作者还将二维索引给转成了一维索引。如果直接将行列相加,就变成一维了。但这样(0, 1)和(1, 0)得到的结果都是1,这样肯定不行。来看看源码的做法怎么做的:

首先,所有行列都加上M-1

在这里插入图片描述

其次将所有的行索引乘上2M-1

在这里插入图片描述

最后行索引和列索引相加,保证了相对位置关系,也不会出现0+1 = 1+0 的现象了。

在这里插入图片描述

刚刚也说了,之前计算的是相对位置索引,并不是实际位置偏执参数。真正使用到的数值需要从relative position bias table,这个表的长度是等于(2M-1)X(2M-1)的。在代码中它是一个可学习参数。

在这里插入图片描述

实现代码如下:

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]

        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)
        
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

7、模型详细配置参数

Swin Transformer的网络架构:

在这里插入图片描述

下图(表7)是原论文中给出的关于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large),其中:

  • win. sz. 7x7表示使用的窗口(Windows)的大小
  • dim表示feature map的channel深度(或者说token的向量长度)
  • head表示多头注意力模块中head的个数

在这里插入图片描述

Swin Transformer 是一种基于 Transformer 的语义分割模型,它在分割任务中取得了突破性的性能。Swin Transformer 采用了跨层连接和位置编码等技巧,能够在保持模型深度的同时提高模型的效率和准确率,并且代码实现方便。 实现 Swin Transformer代码主要有以下几个步骤: 1. 安装 PyTorch 和相关工具包,如 torchvision 和 tqdm 等。 2. 定义 Swin Transformer 模型的主体结构,一般会定义一个 SwinEncoder 和 SwinDecoder 类。其中,SwinEncoder 用于提取特征,SwinDecoder 用于对特征进行分类和分割。 3. 定义模型的输入和输出,包括输入的图片尺寸、分类或分割的类别数等。 4. 实现模型的训练和推理过程。在训练时,需要定义损失函数、优化器和学习率等超参数,并通过反向传播算法不断更新模型的参数。在推理时,需要对输入的图片进行前向传播,得到预测结果。 5. 对训练的模型进行评价,比如计算准确率、召回率和 F1 值等评估指标,以检验模型的性能。 在代码实现过程中,还需要注意以下几点: 1. 为了加速训练,可以采用混合精度训练技巧,即使用 float16 精度计算梯度和参数更新,从而减少显存占用和计算时间。 2. 为了提高模型的泛化能力,可以采用数据增强技巧,比如随机裁剪、随机翻转等,从而增加训练数据的多样性。 3. Swin Transformer 中跨层连接和位置编码的实现比较特殊,需要对代码进行细致的理解和调试。 总之,实现 Swin Transformer 的关键在于理解模型的结构和原理,并实现对应的代码逻辑。只有不断地优化和调试,才能最终得到高效、准确的模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值