Swin Transformer框架结合代码精讲

Swin Transformer 是2021.3在ICCV发表的一篇论文,同时也是这一年ICCV的best paper,在各大检测、分割任务中有着非常出色的结果。

论文地址:https://arxiv.org/pdf/2103.14030.pdf

论文官方代码:https://github.com/microsoft/Swin-Transformer

目录

整体架构

 Swin Transformer Block

Patch Partition

Linear Embedding

Transformer Block

Window based Self-Attention

相对位置编码

Shifted Window Attention 

特征图位移操作

 Attention Mask

 PatchMerging

实验结果

总结


整体架构

        整个模型采取层次化的设计,一共包含4个Stage,与ViT不同,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野,这样解决了多尺寸目标检测的问题,为本文的一大创新。

        这里放一个Swin 和ViT的对比图,这其实可以看成作者对ViT的一个改进,ViT从头至尾都是对全局做self-attention,而swin-transformer是一个窗口在放大的过程,然后self-attention的计算是以窗口为单位去计算的,这样相当于引入了局部聚合的信息,和CNN的卷积过程很相似,就像是CNN的步长和卷积核大小一样,这样就做到了窗口的不重合,区别在于CNN在每个窗口做的是卷积的计算,每个窗口最后得到一个值,这个值代表着这个窗口的特征。而swin transformer在每个窗口做的是self-attention的计算,得到的是一个更新过的窗口,然后通过patch merging的操作,把窗口做了个合并,再继续对这个合并后的窗口做self-attention的计算。

 Swin Transformer Block

         可以注意到这4个stage的Swin Transformer Block都有×2或×6,是2的倍数,因为这个是两个Successive block组合的,两个block的不同在于一个是W-MSA(window multi-head self attention),另一个是SW-MSA(shifted window multi-head self attention),后面会分别详细讲解。

Patch Partition

        首先,输入图像H×W×3,输入到Patch Partition模块,在代码中是PatchEmbed类实现的,我们来看一下PatchEmbed的forward()函数:

    def forward(self, x):
        """Forward function."""
        # padding  确保 H、W为patch_size的整数倍
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C H/4 W/4  C=embed_dim
        if self.norm is not None:   # 下面是一个normalization的一个操作  
            Wh, Ww = x.size(2), x.size(3)   # B C H/4 W/4
            x = x.flatten(2).transpose(1, 2) # (B, H/4×W/4, embed_dim)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
                                                # (B, embed_dim, H/4 W/4)
        return x  # (B, embed_dim, H/4, W/4)

         可以看到PatchEmbed的核心代码就是self.proj()函数,如下:

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = self.proj(x)

        而self.proj()是一个卷积函数,输出的通道数为embed_dim,卷积核的大小就是patch_size,这样下来,就是把输入图像通过一个4×4的卷积,H、W分别缩小了4倍。

Linear Embedding

         在输入到Transformer网络之前,要把维度变为(B,n_tokens, embed_dim)的格式,才能进行接下来的multi-head self attention,对经典的transformer网络不熟悉的同学可以去把transformer看懂了再过来,这里就不讲了。其实这个代码很简单,输入维度是(B, embed_dim, H/4, W/4),这样我们的token的数目就是H/4 × W/4 = HW/16,而每一个token的维度就是我们初始化网络输入的embed_dim,这里有不同尺寸的Swin网络(Swin-T,Swin-S,Swin-B,Swin-L),论文里面参数设置如下:

  • Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
  • Swin-S: C = 96, layer numbers ={2, 2, 18, 2}
  • Swin-B: C = 128, layer numbers ={2, 2, 18, 2}
  • Swin-L: C = 192, layer numbers ={2, 2, 18, 2}

        这里的C就是 embed_dim,后面的layer numbers对应每个stage的Swin Transformer的数目。所以只需要一行代码就能实现维度转换(B,n_tokens, embed_dim)的操作:

x = x.flatten(2).transpose(1, 2)

        当然,  Linear Embedding还有一个位置编码的操作,论文这里用了absolute position embedding。代码里面绝对位置编码就是随机初始化参数,与x的Batch维度以外的相匹配,作为可学习的位置编码参数,如下:

self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))

        接着,x与位置编码相加,再进行flatten和transpose的操作,如下: 

x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # (B, H/4*W/4, embed_dim)

Transformer Block

        Linear Embedding之后就输入到Swin Transformer Block,首先输入的是Windows的多头注意力机制Transformer Block;接着输入的是shifted Windows的多头注意力机制Transformer Block。

Window based Self-Attention

        首先介绍Window based Self-Attention,这个Attention简单来说就是CNN版的ViT,把特征图分为7*7(假设设置的窗口大小为7*7)大小的窗口,对每一个窗口进行Linear Embedding的操作,然后像ViT一样,输入多头注意力机制网络,如图:

https://gitee.com/xiaomoon/image/raw/master/Img/image-20210530155145008.png

        window_partition代码如下:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape 
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows  (windows_num, window_size, window_size, embedding_dim)

        然后把把每个小window的token展平:

x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # (nW*B, window_size*window_size, embedding_dim)

        接下来输入到W-MSA网络,和普通的Transformer网络一样,由(B, token_num, embedding_dim) 得到qkv,然后注意力矩阵attn=q · k.transpose,输出x=attn·v;值得注意的是,我们这里B=windows_num * Batchsize, token_num=window_size*window_size。另外更值得一提的是,这里的注意力矩阵attn是加上了相对位置编码的,后续论文的实验有证明相对位置编码提升了模型性能。WindowAttention 的forward函数:

 def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  #(B*windows_num, head_num, tokens_num, tokens_num)
                                           #  tokens_num = windows_size * windows_size

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

相对位置编码

这里面比较难理解的是相对位置编码是怎么定义的,我们可以在上面看到相对位置编码是与注意力矩阵attn相加的,这里atten的维度为(B*windows_num, head_num, tokens_num, tokens_num),我们要得到与此维度相同的相对位置编码矩阵;

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        很明显,这里用到了初始化window类的两个self变量--self.relative_position_bias_table和self.relative_position_index,先介绍前面relative_position_bias_table,

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        trunc_normal_(self.relative_position_bias_table, std=.02)

        它是通过nn.Parameter初始化,然后变为方差std为0.02的正态分布数据,它是一个二维的tensor数组,shape[0]为(2*window_h-1)*(2*window_w-1),假设windows的高宽都为7,那么shape[0]=13*13=169;shape[1]为多头注意力head的数目。

        重点是self.relative_position_index是怎么定义的,看名字知道它是一个位置索引数组。首先,在一个二维的Windows中,每一个token的位置是一个二维的坐标,但是进入transformer网络后,二维的token要展平成一维的,这就意味着需要把二维的相对位置距离转为一维的,以下是相对位置编码index的代码:

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

        我们这里假设windows_h=windows_w=3,一步一步来看:

window_size = [3, 3]
coords_h = torch.arange(window_size[0])  # 0,1,2
coords_w = torch.arange(window_size[1])  # 0,1,2 
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))

## coords:
tensor([[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]])

       上面是横坐标,下面是纵坐标,合起来看的话就是这样:

我们给每一个位置编号如左图

        接下来flatten的操作:

coords_flatten = torch.flatten(coords, 1) # 2, Wh*Wh
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
上面是横坐标,下面是纵坐标

        接下来,利用广播机制,分别在第一维,第二维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww的张量 

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww

        如上,我们可以看到这个两个tensor,右边的就是左边对应横坐标已经纵坐标的转置,两个tensor相减之后,就得到了每一个位置的分别相对于其他位置的距离,个人看这一方法有类似与求自相关矩阵。我们把结果打印出来看看:

tensor([[[ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0]],  横坐标相对偏移

        [[ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0]]])   纵坐标相对偏移

        我们可以看到对角线都是0,这一点很好解释,3*3窗口中的9个token对自己的偏移都是0。好,我们把横纵坐标合在一起,如下:

tensor([[[ 0,  0],[ 0, -1],[ 0, -2],[-1,  0],[-1, -1],[-1, -2],[-2,  0],[-2, -1],[-2, -2]],

        [[ 0,  1],[ 0,  0],[ 0, -1],[-1,  1],[-1,  0],[-1, -1],[-2,  1],[-2,  0],[-2, -1]],

        [[ 0,  2],[ 0,  1],[ 0,  0],[-1,  2],[-1,  1],[-1,  0],[-2,  2],[-2,  1],[-2,  0]],

        [[ 1,  0],[ 1, -1],[ 1, -2],[ 0,  0],[ 0, -1],[ 0, -2],[-1,  0],[-1, -1],[-1, -2]],

        [[ 1,  1],[ 1,  0],[ 1, -1],[ 0,  1],[ 0,  0],[ 0, -1],[-1,  1],[-1,  0],[-1, -1]],

        [[ 1,  2],[ 1,  1],[ 1,  0],[ 0,  2],[ 0,  1],[ 0,  0],[-1,  2],[-1,  1],[-1,  0]],

        [[ 2,  0],[ 2, -1],[ 2, -2],[ 1,  0],[ 1, -1],[ 1, -2],[ 0,  0],[ 0, -1],[ 0, -2]],

        [[ 2,  1],[ 2,  0],[ 2, -1],[ 1,  1],[ 1,  0],[ 1, -1],[ 0,  1],[ 0,  0],[ 0, -1]],

        [[ 2,  2],[ 2,  1],[ 2,  0],[ 1,  2],[ 1,  1],[ 1,  0],[ 0,  2],[ 0,  1],[ 0,  0]]])

         解释一下第一行是什么意思,[0,0]是指窗口里面的第一个token相对于自己,x轴的偏移与y轴的偏移都是0;[0,-1]是指window里的第2个token相对于第一个token,x轴的偏移是0,y轴的偏移是-1,以此类推···

        那么第二行,[0,1]是指window里的第1个token相对于第2个token,x轴的偏移是0,y轴的偏移是1;[0,-1]是指window里的第3个token相对于第2个token,x轴的偏移是0,y轴的偏移是-1,以此类推···

        这样,我们把x轴和y轴的偏移加起来,就是相对位置的总的偏移,当然,这里面x轴与y轴相加,会有负数的情况,代码里面为了不让出现负数,即相对位置偏移从0开始,进行了如下操作:

relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

         即,x与y都加上(3-1)=2

        当然,代码还有一个操作,就是对所有的x偏移都乘上了(2 * window_size[1] - 1)这个数,这里window_size[1]=3时,就是5;

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

         最后对x轴与y轴求和,就是最后相对位置偏移的index:

relative_position_index = relative_coords.sum(-1)

tensor([[12, 11, 10,  7,  6,  5,  2,  1,  0],
        [13, 12, 11,  8,  7,  6,  3,  2,  1],
        [14, 13, 12,  9,  8,  7,  4,  3,  2],
        [17, 16, 15, 12, 11, 10,  7,  6,  5],
        [18, 17, 16, 13, 12, 11,  8,  7,  6],
        [19, 18, 17, 14, 13, 12,  9,  8,  7],
        [22, 21, 20, 17, 16, 15, 12, 11, 10],
        [23, 22, 21, 18, 17, 16, 13, 12, 11],
        [24, 23, 22, 19, 18, 17, 14, 13, 12]])

       拿对角线的12算,就是12 = (0+2)*5+(0+2)=12,其他位置的偏移也是这么算的。

        前面我们得到了一个(2*window_h-1)*(2*window_w-1) ,head_num即(25,3)的relative_position_bias_table,上面的最大index就是24,即数组的第25个,正好对应上了。所以经过如下操作,就得到了相对位置编码:

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        至于相对位置编码为什么要这样排列,一个Window里面只有9个patch,为什么要弄出这么一个9*9的位置矩阵?其实就是要跟attn矩阵对应起来,我们的attention矩阵就是由9个初始的patch分别得到9个q、k、v,然后q和k做相关矩阵,从另外一个角度讲就是自己跟自己做自相关,这一点不明白的可以先把transformer看明白,比如自相关矩阵第一行,就是每个patch分别与第一个patch的相关度,而第二行就是每个patch与第二个patch的相关度;所以这里的位置编码要这么排列,才能和attn矩阵的位置相对应,这样才能相加。

Shifted Window Attention 

        Shifted Window Attention 算是论文的特色创新点之一了,上面解释的Window Attention是在每个窗口下计算注意力的,为了更好的和其它窗口进行信息交互,Swin Transformer还引入了shifted window操作。

        以论文中的图为例子,左图有4个window,每一个window有4*4个patch(就是token),左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。这里的shift的长度在代码中是这样的

                shift_size=0 if (i % 2 == 0) else window_size // 2,

        它的意思是偶数序号block 的shift_size=0,即进行普通的Window Attention;基数序号block 的shift_size=window_size//2,即窗口尺寸的一半(严谨点,window_size是偶数情况);但这也引入了一个新问题,即window的个数翻倍了,由原本四个窗口变成了9个窗口。在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。

特征图位移操作

         如上图,论文图解是把右下角的A、B、C移到了左上,在代码里面是通过roll函数实现的:

        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))  # # (B, H/4, W/4, embed_dim)
            attn_mask = mask_matrix

         

        图3和图1我们可以看到,它们有相同的数目的window,但是每一个window的内容却不一样。

 Attention Mask

        其实到这里,Shifted Window Attention已经很明朗了,把上面经过roll的图(上面图3)和W-MSA一样,做相同的操作就行了;但是,论文中有一个masked MAS,那么这个mask是什么功能。

         我们先画一个经过roll之后的图,画一个大一点的方便理解:

         如图8*8个大小为4*4的window,经过roll之后,在边界会出现这样的windows:

        在边界的Window会同时具有来自于原图像(橙色、浅蓝色、褐色)的像素。但是由于这些window的不同颜色的像素,在源图像中两者本来就距离较远,所以在提取局部特征的时候,论文认为二者并不适合被认为互为邻域,不适合放在一起计算attention,为了结局这一问题,论文提出了mask 的Attention,我们来看看代码是这么做的。

        拿2*2个大小为2*2的window的图像为例子,我们先对shift之后的区域编一个号:

         先拿连续区域的1号区域为例子,看看attn矩阵是这么算的;首先我们把里面的4个patch展开,经过全连接得到q、k、v的3个矩阵,q 与k的转置进行矩阵相乘之后,得到atten矩阵:

         然后第二个window的,将patch展开,q 与k的转置进行矩阵相乘之后,得到atten矩阵:

        同理第3个window和第4个window的结果,如下:

        上图中有颜色的区域是连续区域做相乘的,二没有颜色的区域,是在原图中相隔很远的区域相乘的结果,论文希望最后attn 与 v 相乘后,能够把没有颜色的区域忽略掉,只算相邻像素的自相关,于是代码设置了和attn同意大小的mask,有颜色区域设置为0,没有颜色区域设置为-100,在attention计算过程中,将得到的attn矩阵与mask相加,由于对该矩阵加了softmax所以相互对应位置均为0,也就是说两个不同的部分相互之间不会参与attention的计算,达成了隔离计算的目的。

tensor([[[[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],


         [[[   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.],
           [   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.]]],


         [[[   0.,    0., -100., -100.],
           [   0.,    0., -100., -100.],
           [-100., -100.,    0.,    0.],
           [-100., -100.,    0.,    0.]]],


         [[[   0., -100., -100., -100.],
           [-100.,    0., -100., -100.],
           [-100., -100.,    0., -100.],
           [-100., -100., -100.,    0.]]]]])

        mask的代码如下:

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size  # np.ceil(a) ,np.floor(a) : 计算各元素的ceiling 值, floor值(ceiling向上取整,floor向下取整)
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1 (1,182,322,1)
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

         在代码里,第二个block的attn矩阵与mask相加如下:

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)

        至此,当第二个shift的W-MSA结束后,会有一个reverse shift的操作,就是把之前roll的像素,再roll还原;

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

 PatchMerging

        上面经过一个stage后,会有一个PatchMerging 的操作,其实就是downsample,将特征图缩小到原来的一半;废话不说,附上PatchMerging的forward代码:

 def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

         画图形象一点,如下:

         然后把得到的4张图cat到一起,就完成了下采样的操作,当然,得到的图的channel增加了4倍,后面经过view成一位的向量之后,通过全连接层,将维度缩小两倍,代码如下:

x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)

self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

        接下来就进入下一个stage,各部分功能代码是一样的,重复就是。

实验结果

        论文代码的各项检查以及分割的指标非常好, 在ImageNet-1K达到了84.2%,在ImageNet-2K精度更是达到了86.4%,在分割任务ADE20K的mIoU达到了53.5%。

总结

        这篇论文创新点很棒,引入window这一个概念,将CNN的局部性引入,还能控制模型整体计算量,通过transformer得到了多尺度大小的特征图,方便了后续的分类以及分割的任务。相对位置编码部分,代码设计的也很是巧妙;在Shift Window Attention部分,用一个mask来进行非连续像素块的Attention矩阵的筛选,很是巧妙,论文非常推荐阅读,另外,文章制作不易,欢迎点赞收藏!

  • 12
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值