A Plain-Language Walkthrough of Swin Transformer, with Annotated Code (Plus Notes on Transformer and ViT)

Understanding Transformers (vanilla / Vision / Swin)

First, a brief introduction to the Transformer and the Vision Transformer.

Transformer

The focus here is multi-head attention.

For the dimension bookkeeping, see: Transformer详解:各个特征维度分析推导 | Hello World (troublemeeter.github.io)

m: the sequence length (maxlen). Once text is fed in, every word / character is a token, the basic input unit of the model; each token is mapped to something the computer can work with, and the result of that mapping is the token's embedding, whose length is usually set to 512. To unify dimensions, sequences shorter than maxlen are padded with 0.

d is $d_{model}$, the dimension of the token (and positional) vectors; the output dimension is also d.

h is the number of attention heads.

input: [batch_size, maxlen, d], i.e. $X \in \mathbb{R}^{m \times d}$ per example (ignoring the batch dimension).

  • Meaning of Q, K, V

    • $Q, K, V \in \mathbb{R}^{m \times d}$

      • Strictly speaking:

      • Q : [maxlen_q, d]
        K = V : [maxlen_k, d]

      • I am not sure why the Transformer literature stresses that the maxlen of Q and of K/V can differ while also saying Q = K = V; in practice the shapes of Q, K, V are the same, including in the original Transformer's self-attention.

    • Strictly speaking, Q and K are used to compute token-to-token similarity: the dot product of two vectors gives their similarity, i.e. an attention score (with causal masking, as in CLIP's text encoder, the score matrix is kept lower-triangular). Physically, Q, K and V all come from the same sentence: each is a matrix whose rows are token embeddings. For the sentence "Hello, how are you?" with length 6 and embedding dimension 300, Q, K and V are all (6, 300) matrices.

    • If Q and K were identical (no separate projections), the model would generalize poorly, because the vectors would all be projected into the same space; such a matrix does a poor job of refining V through $QK^T$.

    • With different projections for Q, K and V, the embedding matrix is projected into different spaces, which generalizes better.

    • My own intuition here is a bit hand-wavy: multiplying the first matrix by the second captures, to some extent, the ordering relationships between tokens; if Q and K were the same, $QK^T$ would be symmetric (upper and lower triangles identical), which might lose that directionality. I am not entirely sure.

First, Q, K and V are linearly projected. Initially Q = K = V = $X \in \mathbb{R}^{\mathcal{B} \times m \times d}$; with multi-head attention, each head works in dimension $d_h = d / h$.
$$
\begin{aligned}
Q^*_i &= Q W_i^Q, \quad W_i^Q \in \mathbb{R}^{d \times d_h}, \quad Q^*_i \in \mathbb{R}^{\mathcal{B} \times m \times d_h} \\
K^*_i &= K W_i^K, \quad W_i^K \in \mathbb{R}^{d \times d_h}, \quad K^*_i \in \mathbb{R}^{\mathcal{B} \times m \times d_h} \\
V^*_i &= V W_i^V, \quad W_i^V \in \mathbb{R}^{d \times d_h}, \quad V^*_i \in \mathbb{R}^{\mathcal{B} \times m \times d_h}
\end{aligned}
$$

Self-attention is then computed for each head $\text{head}_i$:
$$f = \text{Attention}(Q, K, V) = \text{SoftMax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$
Here $d_k$ is the key dimension; per head $d_k = d_h = d / h$ (for example $512 / 8 = 64$ in the original Transformer).
Strictly, this should be written per query as $f=\text{Attention}(q, K, V)=\text{SoftMax}\!\left(\frac{q K^T}{\sqrt{d_k}}\right) V$, i.e. each query $q$ is attended separately and the results are stacked, which is exactly why the whole computation can be written as one matrix product.

Softmax produces token-to-token similarities, so $QK^T$ has shape m × m. Multiplying this similarity (score) matrix by V gives an m × d result; the point of multiplying by V is that, guided by the scores, every token's vector gets adjusted in every dimension (every column) by the other tokens.
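To make the shapes above concrete, here is a minimal sketch of multi-head scaled dot-product attention in PyTorch (hypothetical sizes m = 6, d = 512, h = 8; random weights, no learned parameters):

import torch
import torch.nn.functional as F

B, m, d, h = 2, 6, 512, 8          # batch, sequence length, model dim, heads
d_h = d // h                        # per-head dimension d_k = 64

X = torch.randn(B, m, d)            # token embeddings
W_q = torch.randn(d, d)             # packed projections for all heads
W_k = torch.randn(d, d)
W_v = torch.randn(d, d)

# project and split into heads: (B, m, d) -> (B, h, m, d_h)
Q = (X @ W_q).view(B, m, h, d_h).transpose(1, 2)
K = (X @ W_k).view(B, m, h, d_h).transpose(1, 2)
V = (X @ W_v).view(B, m, h, d_h).transpose(1, 2)

scores = Q @ K.transpose(-2, -1) / d_h ** 0.5   # (B, h, m, m): token-to-token similarity
attn = F.softmax(scores, dim=-1)
out = attn @ V                                   # (B, h, m, d_h)
out = out.transpose(1, 2).reshape(B, m, d)       # concatenate heads -> (B, m, d)
print(scores.shape, out.shape)                   # torch.Size([2, 8, 6, 6]) torch.Size([2, 6, 512])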

Vision Transformer

Core idea: split the image into patches; each patch is a token.

A text input is a 2-D tensor [seq_len, d], but an image is 3-D [H, W, C], so ViT chops the image into patches to turn it into a token sequence.


ViT splits the image into fixed-size patches and flattens each patch to form an image token. For a 224 * 224 * 3 image with 16 * 16 patches, there are 14 * 14 = 196 patches, i.e. 196 tokens.

Each patch is first turned into a token: a flattened patch has 16 * 16 * 3 = 768 values. A linear projection maps it to the desired embedding size, say 768; equivalently, a convolution with kernel_size = 16 * 16, stride = 16 and 768 output channels does the same thing. Either way the result is 196 * 768 (conv output size: (224 - 16 + 2 * 0) / 16 + 1 = 14, so [3, 224, 224] -> [768, 14, 14]).
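A quick sketch of the conv-as-patch-embedding equivalence described above (illustrative sizes only, not ViT's actual code):

import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)

# 16x16 patches with stride 16: each patch becomes one 768-dim token
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
tokens = proj(img)                          # (1, 768, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)  # (1, 196, 768): 196 tokens of dim 768
print(tokens.shape)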

The tokens are then fed into multi-head self-attention.
The later softmax therefore operates on a score matrix of size seq_len * seq_len = 196 (14 * 14) * 196.

❓ The problem: if the image resolution grows while the patch size stays fixed, the number of tokens grows. At 448 * 448 the patch grid is 28 * 28, which is already a big jump, and self-attention cost is quadratic in the token count.
This is the motivation for Swin Transformer.

  • Computational complexity
    • First, the three linear projections for Q, K, V

      • The patch tokens Q, K, V: [h*w, C] = [14*14, 768] are multiplied by weights W: [C, C], so the cost is 3 * h * w * C * C, giving [h*w, C]
    • Attention scores and weighted sum

      • $QK^T$: [h*w, C] times [C, h*w], cost h * w * h * w * C, giving [h*w, h*w]
      • $(QK^T) \times V$: [h*w, h*w] times [h*w, C], cost h * w * h * w * C, giving [h*w, C]
    • Output projection

      • The attention output [h*w, C] goes through W: [C, C], cost h * w * C * C
    • Total complexity:
      $$\Omega(\mathrm{MSA}) = 4\,hwC^2 + 2\,(hw)^2 C$$

Swin Transformer

Overview

Core idea: split the feature map into windows, treat each window like a separate image, and compute self-attention within each window.
To address ViT's drawback, Swin introduces the notion of windows.

The (56 * 56 patch) feature map is split into 8 * 8 windows, each containing 7 * 7 patches (patch size 4 * 4). Multi-head attention is then computed window by window, so the later softmax only sees a seq_len * seq_len = 49 (7 * 7) * 49 score matrix.

Note that Swin defines windows by their size: for window size M (in patches), there are (H / M) * (W / M) windows.

  • Computational complexity
    • For a single M x M window:
      $$\Omega(\mathrm{MSA\ per\ window}) = 4\,M^2 C^2 + 2\,(M^2)^2 C$$

    • Multiplying by the number of windows, hw / M^2:
      $$\Omega(\mathrm{W\text{-}MSA}) = \left(4 M^2 C^2 + 2 (M^2)^2 C\right)\cdot \frac{hw}{M^2} = 4\,hwC^2 + 2\,M^2 hw\,C$$
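Plugging stage-1 numbers into the two formulas (h = w = 56, C = 96, M = 7) shows why windowing matters; a rough back-of-the-envelope script:

def msa_flops(h, w, C):
    # global self-attention over all h*w tokens
    return 4 * h * w * C**2 + 2 * (h * w)**2 * C

def wmsa_flops(h, w, C, M):
    # window self-attention with M x M windows
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C

h = w = 56; C = 96; M = 7
print(f"global MSA : {msa_flops(h, w, C):.3e}")     # ~2.0e9, dominated by the (hw)^2 term
print(f"window MSA : {wmsa_flops(h, w, C, M):.3e}") # ~1.4e8, linear in hw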

Code walkthrough

The "world-famous painting" (the overall Swin Transformer architecture figure):
[Figure: overall Swin Transformer architecture]

The full SwinTransformer module

It is built essentially in the order shown in the architecture figure above.

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    # Initialization
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, fused_window_process=False, **kwargs):
        # num_heads=[3, 6, 12, 24]: the head count doubles each stage because patch merging doubles the embedding dim
        # depths 2 2 6 2 are all even, so W-MSA and SW-MSA blocks alternate within each stage
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        # patch embedding turns the 224*224*3 image into 56*56*96 (implemented with a convolution)
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        # whether to add an absolute position embedding (as in ViT or the original Transformer)
        # optional either way
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)  # dropout

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        # the drop-path rates are spaced evenly:
        # sum(depths) values evenly spaced from 0 to drop_path_rate, one per block


        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),  # the embedding dim doubles every stage
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               fused_window_process=fused_window_process)
            # the loop builds the four stages
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)  # global average pooling over tokens
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()  # pooling + linear = classification head

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)  # patch embedding first
        if self.ape:  # optionally add the absolute position embedding
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)  # feature extraction via forward_features
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
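A minimal usage sketch of the full model (assuming the classes and helpers above, e.g. PatchEmbed, BasicLayer and trunc_normal_, are in scope; the default arguments correspond to Swin-T):

import torch

model = SwinTransformer(img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        window_size=7)
x = torch.randn(2, 3, 224, 224)          # a batch of two images
logits = model(x)
print(logits.shape)                       # torch.Size([2, 1000])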

patch partition & Linear embedding

This step is really a downsampling: it turns H * W * 3 into H/4 * W/4 * 48 (= 4*4*3), implemented with a kernel_size = 4, stride = 4 convolution. The paper talks about 4 * 4 patches; after this convolution, each output position corresponds to one such patch (at the new resolution a "patch" is effectively 1 * 1 pixel), because the 4 * 4 neighborhood has been stacked into the channel dimension. Strictly speaking, each of these positions is then converted into an embedding.
Moreover, patch partition and linear embedding are both folded into PatchEmbed via a single convolution, turning 224 * 224 * 3 into 56 * 56 * (48 ->) 96.

class PatchEmbed(nn.Module):
    # this is the initial downsampling step
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size) # convert to a tuple
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # produces (H/4, W/4, embed_dim=96): each 4*4*3=48-value patch is projected to 96 channels
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
        # self.norm = norm_layer(embed_dim) creates a normalization layer over embed_dim features
        # and stores it on self.norm, so the forward pass can normalize the embedded features with it

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # check that the input matches the expected image size
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        # conv projection above, then flatten the spatial dims into one axis of length 56*56=3136
        if self.norm is not None:
            x = self.norm(x)
        # normalization
        return x

    def flops(self):
        # compute FLOPs
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # flops = 56 * 56 * 96 * 3 * (4 * 4)
        # standard conv FLOPs: (2 * k_h * k_w * c_in) per output element (multiply then add) times (c_out * h_o * w_o) output elements
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
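A quick shape check of PatchEmbed as defined above (hypothetical input; just verifying the 224 x 224 x 3 -> 3136 x 96 bookkeeping):

import torch

x = torch.randn(1, 3, 224, 224)
embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
print(embed(x).shape)   # torch.Size([1, 3136, 96]): 56*56 patch tokens of dim 96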

swin transformer block

From the "two successive blocks" part of the architecture figure, the input first goes through LayerNorm and then self-attention. The authors wrap the attention computation in a WindowAttention class; the video BV1bq4y1r75w explains it well. Unlike the Transformer and the Vision Transformer, Swin uses a relative position bias inside window attention, which is the trickiest part to understand.

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        # shift_size must satisfy 0 <= shift_size < window_size, otherwise an assertion error is raised

        self.norm1 = norm_layer(dim)
        # the most important part: the window attention module
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # (the SW-MSA attention mask is built further below, not here)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # next, build the attention mask
        if self.shift_size > 0:
            # shift_size > 0 means the windows are actually shifted (SW-MSA)
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1: a single "image" with only H and W, no real batch or channels
            h_slices = (slice(0, -self.window_size),  # with window_size = 7 (7 patches), this slices 0 to -7
                        slice(-self.window_size, -self.shift_size),  # slices -7 to -3
                        slice(-self.shift_size, None))  # slices -3 to the end
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
                    # label each region of the single img_mask with a different id

            # still part of initialization
            mask_windows = window_partition(img_mask, self.window_size)  # nW * B, window_size, window_size, 1
            # mask_windows(num_windows*B, window_size, window_size, C=1)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # now a 2-D matrix: dim 0 is the number of windows, dim 1 the tokens inside each window
            # each window now holds region ids; tokens that came from the same region share the same id
            # flattening each window into a row and a column and subtracting gives all pairwise id differences
            # a difference of 0 means two tokens come from the same region; nonzero means they are unrelated
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # unsqueeze adds a dimension so broadcasting a row against a column gives all pairwise differences (e.g. 4 x 4 = 16 for 4 tokens)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            # entries of 0 stay 0; nonzero entries become -100 (effectively -inf before softmax): this is the mask
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

        # everything above is initialization: self.norm1, self.attn etc. are set up in advance,
        # and the mask is precomputed here so the forward pass can simply use it
        # the forward pass then assembles each part of the transformer block step by step,
        # calling the modules defined in __init__


    def forward(self, x):
        H, W = self.input_resolution  # x: 1, 3136, 96   B = 1
        B, L, C = x.shape  # L = 56*56=3136
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)  # layer norm
        x = x.view(B, H, W, C)  # -> 1, 56, 56, 96

        # cyclic shift
        # if the windows are shifted, roll the pixels along H and W first, then partition into windows
        if self.shift_size > 0:  # shifted windows
            if not self.fused_window_process:
                # x: B, H, W, C
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # shifts=(-self.shift_size, -self.shift_size) rolls along the H and W dims
                # torch.roll moves the leading rows/columns to the end of the tensor
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # note: attn_mask is precomputed in __init__ and is None when shift_size == 0; when shift_size > 0 it is used here
        # WindowAttention then computes the relative position index and adds the relative position bias internally

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # reshape the windows back to spatial (window_size x window_size) form

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                # window_reverse stitches the windows back to (B, H', W', C); the roll below undoes the cyclic shift
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            # without a shift, window_reverse alone restores the feature map
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)  # residual connection, with drop_path (stochastic depth) applied to the branch

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # attention FLOPs are counted per window with N = window_size * window_size, matching the hw*hw term earlier
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

First, W-MSA and SW-MSA; for now, treat WindowAttention as a black box.

  • In the initialization phase, self.attn is created first, followed by drop_path, LayerNorm and the MLP; then the mask matrix is built, depending on shift_size

    • If shift_size > 0, the windows really are shifted, i.e. SW-MSA.

      • In that case we need to build a mask. Why? Because the feature map gets rolled: some pixels from the left and the top are moved to the right and the bottom (the rolling itself happens in forward). Those moved pixels are not actually adjacent to their new neighbors, so their attention scores must be set to -inf; this keeps the "outsiders" and the "locals" from being mixed together during self-attention.

      • Concretely, if we partition the shifted map into small windows, the windows along the right and bottom edges clearly contain outsiders, so we first label the regions; what the labels are used for is explained below.

      • Think of this mask as a (1, H, W, 1) tensor, a 56 * 56 "image" with no batch or channel. How do we find the unrelated pixels? By labelling the shifted mask: like drawing borders on a map, each region gets its own id. See the code comments for how the mask is processed.

      • Why slice at (-7, -3) and (-3, None)? Because each shift moves 3 rows/columns to the far side, while the window size is 7*7, so a boundary window ends up containing 3 rows or columns of outsiders (see the figure above), and we have to tell them apart.

      • h_slices = (slice(0, -self.window_size),  # with window_size = 7 (7 patches), this slices 0 to -7
                    slice(-self.window_size, -self.shift_size),  # slices -7 to -3
                    slice(-self.shift_size, None))  # slices -3 to the end
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        
      • After labelling, take the bottom-right window as an example: it is a mixture of regions. window_partition splits the feature map into small windows; the shape goes (B, H, W, C) -> (B, H // window_size, window_size, W // window_size, window_size, C) -> (B, nWindow', nWindow', window_size, window_size, C) -> (B*nWindow, window_size, window_size, C). These small windows are what window self-attention operates on.

      • def window_partition(x, window_size):
            """
            Args:
                x: (B, H, W, C)
                window_size (int): window size
        
            Returns:
                windows: (num_windows*B, window_size, window_size, C)
            """
            B, H, W, C = x.shape
            x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
            windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
            return windows
            # B, H // window_size, window_size, W // window_size, window_size, C
            # -> (num_windows*B, window_size, window_size, C)
            # where num_windows = (H // window_size) * (W // window_size)
            # i.e. the windows are stacked along dim 0 (very similar to how patches are made)
        
      • The code below builds the window mask, attn_mask; the shape goes (B*nWindow, window_size, window_size) -> (B*nWindow, window_size*window_size), i.e. each window's mask is flattened into a row (DASOU's video BV1zT4y197Fe is highly recommended: watch it and then re-read the code, assuming you already know Transformer and Vision Transformer). Remember we labelled the mask, one id per region. Take a 2*2 window as an example: flatten it, then compute the pairwise differences between its elements. A difference of 0 means two tokens belong to the same region; a nonzero difference means they come from different regions and have no real adjacency, so they must be masked out, i.e. set to a large negative value, so that after the self-attention softmax their weights are essentially zero and they do not affect the result. The "flatten and subtract" trick uses unsqueeze plus broadcasting: unsqueeze adds a dimension, so a row of 4 minus a column of 4 broadcasts to a 4*4 matrix. A small numeric sketch of this whole mask construction is given right after this code walkthrough.

      • # still part of initialization
        mask_windows = window_partition(img_mask, self.window_size)  # nW * B, window_size, window_size, 1
        # mask_windows(num_windows*B, window_size, window_size, C=1)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # now a 2-D matrix: dim 0 is the number of windows, dim 1 the tokens inside each window
        # each window now holds region ids; tokens that came from the same region share the same id
        # flattening each window into a row and a column and subtracting gives all pairwise id differences
        # a difference of 0 means two tokens come from the same region; nonzero means they are unrelated
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # unsqueeze adds a dimension so broadcasting a row against a column gives all pairwise differences (e.g. 4 x 4 = 16 for 4 tokens)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        # entries of 0 stay 0; nonzero entries become -100 (effectively -inf before softmax): this is the mask
        
    • If there is no shift, none of this fuss is needed and attn_mask is None.

  • forward

    • With the mask prepared, the forward pass does the real work. If the windows need shifting, torch.roll is used; it moves the leading slices of a dimension to the end, as the example below shows.

    • import torch
      
      x = torch.arange(80).reshape((1, 4, 4, 5))
      print(x)
      print('------------------------------------')
      # 使用 torch.roll() 进行循环移位
      shifted_x = torch.roll(x, shifts=(-3, -2), dims=(1, 2))
      print(shifted_x)
      
      Output:
      tensor([[[[ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9],
                [10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19]],
      
               [[20, 21, 22, 23, 24],
                [25, 26, 27, 28, 29],
                [30, 31, 32, 33, 34],
                [35, 36, 37, 38, 39]],
      
               [[40, 41, 42, 43, 44],
                [45, 46, 47, 48, 49],
                [50, 51, 52, 53, 54],
                [55, 56, 57, 58, 59]],
      
               [[60, 61, 62, 63, 64],
                [65, 66, 67, 68, 69],
                [70, 71, 72, 73, 74],
                [75, 76, 77, 78, 79]]]])
      ------------------------------------
      tensor([[[[70, 71, 72, 73, 74],
                [75, 76, 77, 78, 79],
                [60, 61, 62, 63, 64],
                [65, 66, 67, 68, 69]],
      
               [[10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19],
                [ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9]],
      
               [[30, 31, 32, 33, 34],
                [35, 36, 37, 38, 39],
                [20, 21, 22, 23, 24],
                [25, 26, 27, 28, 29]],
      
               [[50, 51, 52, 53, 54],
                [55, 56, 57, 58, 59],
                [40, 41, 42, 43, 44],
                [45, 46, 47, 48, 49]]]])
      
    • This is the image-shifting code; it produces the shifted windows x_windows with shape (nW*B, window_size, window_size, C).

    • if self.shift_size > 0:  # shifted windows
          if not self.fused_window_process:
              # x: B, H, W, C
              shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
              # shifts=(-self.shift_size, -self.shift_size) rolls along the H and W dims
              # torch.roll moves the leading rows/columns to the end of the tensor
              # partition windows
              x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
          else:
              x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
      else:
          shifted_x = x
          # partition windows
          x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
      
    • Next, x_windows is reshaped to (nW*B, window_size*window_size, C). With the shifted windows x_windows and the attn_mask built in __init__, we can run window self-attention (via WindowAttention, already declared as self.attn in __init__).

    • x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
      
      # W-MSA/SW-MSA
      attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
      # note: attn_mask is precomputed in __init__ and is None when shift_size == 0; when shift_size > 0 it is used here
      # WindowAttention then computes the relative position index and adds the relative position bias internally
      
      # merge windows
      attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
      # reshape the windows back to spatial form
      
    • Then the windows have to be put back: window_reverse restores the original (B, H, W, C) layout; the shape goes (B, num_window, num_window, window_size, window_size, C) -> (B, num_window, window_size, num_window, window_size, C) -> (B, H, W, C).

    • def window_reverse(windows, window_size, H, W):
          """
          Args:
              windows: (num_windows*B, window_size, window_size, C)
              window_size (int): Window size
              H (int): Height of image
              W (int): Width of image
      
          Returns:
              x: (B, H, W, C)
          """
          B = int(windows.shape[0] / (H * W / window_size / window_size))
          x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
          x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
          # (B, num_windows_h, num_windows_w, window_size, window_size, C)
          # -> (B, num_windows_h, window_size, num_windows_w, window_size, C)
          # -> (B, H, W, C)
          return x
      
      • In addition, if the map was shifted, after the reverse we also have to roll it back: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)); note the signs are now positive, whereas the forward shift used (-self.shift_size, -self.shift_size).
      • If there was no shift, no extra roll is needed after the reverse.
      • One question of my own: is fused_window_process really necessary? The version of the source code in the video I watched does not have it; if anyone knows, please enlighten me.
    • # reverse cyclic shift
      if self.shift_size > 0:
          if not self.fused_window_process:
              shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
              # window_reverse stitches the windows back to (B, H', W', C); the roll below undoes the cyclic shift
              x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
          else:
              x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
      else:
          # without a shift, window_reverse alone restores the feature map
          shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
          x = shifted_x
      x = x.view(B, H * W, C)
      x = shortcut + self.drop_path(x)  # residual connection, with drop_path (stochastic depth) applied to the branch
      
    • Finally there is the flops computation and the like.

    • def flops(self):
          flops = 0
          H, W = self.input_resolution
          # norm1
          flops += self.dim * H * W
          # W-MSA/SW-MSA
          nW = H * W / self.window_size / self.window_size
          flops += nW * self.attn.flops(self.window_size * self.window_size)
          # attention FLOPs are counted per window with N = window_size * window_size, matching the hw*hw term earlier
          # mlp
          flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
          # norm2
          flops += self.dim * H * W
          return flops
      
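As promised above, here is a small numeric sketch of how the SW-MSA mask comes together, shrunk to a hypothetical 6 x 6 feature map with window_size = 3 and shift_size = 1 so the printed tensors stay readable (same logic as the code above, just smaller sizes):

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

H = W = 6
window_size, shift_size = 3, 1

# label the regions of the shifted feature map, exactly like SwinTransformerBlock.__init__
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)       # pairwise id differences
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

print(img_mask[0, :, :, 0])     # the 6x6 map of region ids 0..8
print(attn_mask.shape)          # (4 windows, 9, 9): -100 wherever two tokens come from different regions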

windowAttention

This part is the heart of it. Code first (with plenty of my own comments worth reading):

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        # dim is the input channel dimension
        # there is also dropout in here, an interesting implementation detail
        # qk_scale overrides the default 1/sqrt(d_k) scaling
        # num_heads grows as patch merging increases the channel (embedding) dim, so each head keeps the same per-head size
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # if qk_scale is given it is used as the scale; otherwise the scale is head_dim ** -0.5, i.e. 1/sqrt(d_k)

        # define a parameter table of relative position bias
        # relative position bias parameters
        # i.e. softmax(QK^T / sqrt(d_k) + B) V, where B is the relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        # compute the relative position index
        # the index is fixed (not learnable), so it is stored in a register_buffer
        # the bias table B, however, is learnable; the values of the index matrix are used to look up entries of the table, roughly:
        # for i, value in enumerate(table):
        #     position[torch.where(position == i)] = value
        coords_h = torch.arange(self.window_size[0])  #
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # after the relative position machinery, the usual layers
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # nn.Linear(in_features, out_features, bias)
        # out_features * 3 so that one projection produces query, key and value
        # a single linear projection layer
        self.attn_drop = nn.Dropout(attn_drop)
        # dropout randomly zeroes activations
        self.proj = nn.Linear(dim, dim)
        # project back to the original feature size
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
        # everything above is initialization

    # forward pass
    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # the linear layer below produces Q, K and V stacked along the channel dim
        # N is the sequence length (number of tokens in a window)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # permute to (3, B_, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # transpose the last two dims
        # attn shape: (window_size*window_size) x (window_size*window_size), i.e. all token pairs in a window

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        # # TODO: could this lookup be written as the loop below instead?
        # for i, value in enumerate(self.relative_position_bias_table):
        #     self.relative_position_index[torch.where(self.relative_position_index == i)] = value
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            # mask shape: (num_windows, Wh*Ww, Wh*Ww), where Wh*Ww is the number of tokens per window
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # after unsqueeze, the mask broadcasts as [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length N (e.g. N = 7*7 = 49, so attention is 49*49)
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N  # the (hw)*(hw)*C term from the complexity analysis
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)  # hw * hw * c
        # x = self.proj(x)
        flops += N * self.dim * self.dim  # hw * c * c
        return flops


  • In short: build a (2, window_size, window_size) grid of coordinates, one channel counting rows and one counting columns; flatten both, add a dimension, and subtract with broadcasting (M in the figure is window_size). After shifting and scaling the row/column differences, the two channels are summed; the resulting values lie in [0, (2M-1)^2 - 1]. A learnable bias table with (2M-1) * (2M-1) entries (per head) is created, and the values of the index matrix are used as indices into that table to fetch the biases.

  • In __init__, this produces the relative position index: the two coordinate channels are summed together (what I think of as adding the two small coordinate sheets); a small numeric example with M = 2 follows this walkthrough.

  • coords_h = torch.arange(self.window_size[0])  #
    coords_w = torch.arange(self.window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += self.window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    self.register_buffer("relative_position_index", relative_position_index)
    
  • Next come the usual operations: packing Q, K, V into one linear layer, dropout, and the output projection.

  • # after the relative position machinery, the usual layers
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    # nn.Linear(in_features, out_features, bias)
    # out_features * 3 so that one projection produces query, key and value
    # a single linear projection layer
    self.attn_drop = nn.Dropout(attn_drop)
    # dropout randomly zeroes activations
    self.proj = nn.Linear(dim, dim)
    # project back to the original feature size
    self.proj_drop = nn.Dropout(proj_drop)

    trunc_normal_(self.relative_position_bias_table, std=.02)
    self.softmax = nn.Softmax(dim=-1)
    # everything above is initialization
    
  • forward

    • The forward pass first computes $QK^T$. The shape bookkeeping: (num_windows*B, N (tokens per window), 3, num_heads, head_dim) -> (3, num_windows*B, num_heads, N (= window_size*window_size), head_dim).

    • # forward pass
      def forward(self, x, mask=None):
          """
          Args:
              x: input features with shape of (num_windows*B, N, C)
              mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
          """
          B_, N, C = x.shape
          # the linear layer below produces Q, K and V stacked along the channel dim
          # N is the sequence length (number of tokens in a window)
          qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # permute to (3, B_, num_heads, N, head_dim)
          q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
      
          q = q * self.scale
          attn = (q @ k.transpose(-2, -1))  # transpose the last two dims
          # attn shape: (window_size*window_size) x (window_size*window_size), i.e. all token pairs in a window
      
    • Next, relative_position_bias is computed: the bias table is a learnable nn.Parameter, and the index tensor built earlier is used to gather values from the table into a (window_size*window_size, window_size*window_size, nH) tensor (nH = number of heads). After a permute, this bias is added to the attention scores (an unsqueeze adds a dimension so it broadcasts over the batch of windows).

    • relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
      relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
      attn = attn + relative_position_bias.unsqueeze(0)
      
    • Then comes the softmax: the similarity and the relative position bias are computed first, and the mask (if any) is added after being reshaped to broadcast as (1, nW, 1, Mh*Mw, Mh*Mw).

    • After that come the output projection and dropout.

    • if mask is not None:
          nW = mask.shape[0]
          # mask shape: (num_windows, Wh*Ww, Wh*Ww), where Wh*Ww is the number of tokens per window
          attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
          # after unsqueeze, the mask broadcasts as [1, nW, 1, Mh*Mw, Mh*Mw]
          attn = attn.view(-1, self.num_heads, N, N)
          attn = self.softmax(attn)
      else:
          attn = self.softmax(attn)
      
      x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
      x = self.proj(x)
      x = self.proj_drop(x)
      return x
      
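And, as promised, a small sketch of the relative-position-index construction for a hypothetical 2 x 2 window (M = 2), so the index values can be checked by hand; the steps mirror the __init__ code above:

import torch

M = 2  # window size

coords = torch.stack(torch.meshgrid([torch.arange(M), torch.arange(M)]))  # (2, M, M): row and column grids
coords_flatten = torch.flatten(coords, 1)                              # (2, M*M)
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]          # (2, M*M, M*M): pairwise (dy, dx)
rel = rel.permute(1, 2, 0).contiguous()                                # (M*M, M*M, 2)
rel[:, :, 0] += M - 1                                                  # shift dy to start from 0
rel[:, :, 1] += M - 1                                                  # shift dx to start from 0
rel[:, :, 0] *= 2 * M - 1                                              # make each (dy, dx) pair map to a unique id
relative_position_index = rel.sum(-1)                                  # (M*M, M*M), values in [0, (2M-1)^2 - 1]
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])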

Patch Merging

This part does the downsampling. Code first:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # requires H and W to be even
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)  # linear projection from 4*C down to 2*C

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops

Not much to say here: the input x of shape (B, H*W, C) is reshaped to (B, H, W, C); each merge halves H and W and, before the linear layer, quadruples C. Concretely, it takes every other element along H and W, concatenates the four sub-grids along the channel dimension, and the linear layer then reduces 4*C down to 2*C.

x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
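A quick shape check using the PatchMerging class above (hypothetical stage-1 sizes: 56 x 56 tokens, C = 96):

import torch

x = torch.randn(1, 56 * 56, 96)                    # (B, H*W, C) tokens from stage 1
merge = PatchMerging(input_resolution=(56, 56), dim=96)
y = merge(x)
print(y.shape)                                     # torch.Size([1, 784, 192]): 28*28 tokens, channels doubled to 192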

Basic layer

The code is organized slightly differently from the architecture figure: a BasicLayer contains both the SwinTransformerBlocks and the PatchMerging. It stacks the Swin blocks and then appends a PatchMerging, which plays the role of the downsampling step; a stripped-down structural sketch follows below.
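The sketch below is my own simplification: it omits drop_path scheduling, checkpointing, fused_window_process and the flops bookkeeping of the real BasicLayer, and reuses the SwinTransformerBlock and PatchMerging classes shown above:

import torch.nn as nn

class BasicLayerSketch(nn.Module):
    """One Swin stage: `depth` SwinTransformerBlocks (alternating W-MSA / SW-MSA), then optional downsampling."""

    def __init__(self, dim, input_resolution, depth, num_heads, window_size, downsample=None):
        super().__init__()
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                # even blocks use W-MSA (shift 0), odd blocks use SW-MSA (shift window_size // 2)
                shift_size=0 if i % 2 == 0 else window_size // 2)
            for i in range(depth)
        ])
        # PatchMerging (or None for the last stage) halves the resolution and doubles dim
        self.downsample = downsample(input_resolution, dim=dim) if downsample is not None else None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x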

SwinTransformer

Finally, the stages are stacked up exactly as in the overall architecture figure.
