【SwinTransformer源码阅读二】Window Attention和Shifted Window Attention部分

先放一下SwinTransformer的整体结构,图片源于原论文,可以发现,在Transformer的Block中 W-MSA(Window based multi-head self attention) 和 SW-MSA是关键组成部分。W-MSA出现在某阶段的奇数层,SW-MSA出现在某阶段的偶数层,W-MSA考虑的是单个窗口的信息,SW-MSA考虑的是不同窗口间的信息。

在这里插入图片描述

虽然从网络架构图里看,W-MSA和SW-MSA为两个不同的模块,但是在代码层面,两者是同一个代码片段,只是在计算SW-MSA时候,在计算完W-MSA后,然后通过代码进行滑动窗口,即cyclic shift操作,多计算了一个mask的操作。下面将针对代码进行分析。

W-MSA的代码

【注意】注释第一句话:Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.
代码注释中的中文,是以配置文件中 swin-tiny 相关的量 来进行注释的。

#窗口注意力
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim#96*(2^layer_index 0,1,2,3...)
        self.window_size = window_size  # Wh, Ww (7,7)
        self.num_heads = num_heads#[3, 6, 12, 24]
        head_dim = dim // num_heads#(96//3=32,96*2^1 // 6=32,...)
        self.scale = qk_scale or head_dim ** -0.5#default:head_dim ** -0.5

        # define a parameter table of relative position bias
        #定义相对位置偏置表格
        #[(2*7-1)*(2*7-1),3]
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        #得到一对在窗口中的相对位置索引
        coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        #让相对坐标从0开始
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        #relative_coords[:, :, 0] * (2*7-1)
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        #为位置偏置表中索引值,位置偏移表(13*13,nHeads)索引0-168
        #索引值为 (49,49) 值在0-168对应位置偏移表的索引 
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        #dim*(dim*3)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        #attn_drop=0.0
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        #初始化相对位置偏置值表(截断正态分布)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
    #模块的前向传播
    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape#输入特征的尺寸
        #(3, B_, num_heads, N, C // num_heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q/k/v: [B_, num_heads, N, C // num_heads]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # q*head_dim ** -0.5
        q = q * self.scale
        # attn:B_, num_heads,N,N
        attn = (q @ k.transpose(-2, -1))
        # 在 随机在relative_position_bias_table中的第一维(169)选择position_index对应的值,共49*49个
        #由于relative_position_bias_table第二维为 nHeads所以最终表变为了 49*49*nHead 的随机表
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        #attn每一个批次,加上随机的相对位置偏移 说民attn.shape=B_,num_heads,Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        #mask 在某阶段的奇数层为None 偶数层才存在
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        #进行 dropout
        attn = self.attn_drop(attn)
        #attn @ v:B_, num_heads, N, C/num_heads 
        #x: B_, N, C 其中 
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        #经过一层全连接
        x = self.proj(x)
        #进行drop out
        x = self.proj_drop(x)
        return x
关于W-MSA中的注意力机制的运算,其实就是按照下面这个公式来进行的,在这个公式里,其实 QKV 三者均是又输入经过一个全连接层(nn.Linear())得到的,这个在代码里很好看明白。关键是在W-MSA中,增加了一个位置偏移量 B,这里的B相关计算也是W-MSA中的关键一步,下面进行记录下。

在这里插入图片描述
在这里插入图片描述

位置偏移量 B 的代码详解

这里关键是理解 relative_position_bias_table 和 relative_position_index ,这两个矩阵的对应关系设计的比较巧妙,即relative_position_index 刚好设计为 relative_position_bias_table 所对应的网格数量

        # define a parameter table of relative position bias
        #定义相对位置偏置表格
        #[(2*7-1)*(2*7-1),3]
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        #得到一对在窗口中的相对位置索引
        coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        #让相对坐标从0开始
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        #relative_coords[:, :, 0] * (2*7-1)
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        #为位置偏置表中索引值,位置偏移表(13*13,nHeads)索引0-168
        #索引值为 (49,49) 值在0-168对应位置偏移表的索引 
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        #注册为不可学习变量
        self.register_buffer("relative_position_index", relative_position_index)
        #dim*(dim*3)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        #attn_drop=0.0
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        #初始化相对位置偏置值表(截断正态分布)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
relative_position_bias_table :设置的一个可学习的 (2 x window_size[0]-1)x(2 x window_size[1]-1) x nHeads 的随机变量(利用截断正态分布赋值),如果以代码中第一个阶段的参数量为例,则 window_size[0]=window_size[1]=7, 在第一个阶段 nHeads=3 。即该表中存储的时候一系列的随机数,用于位置编码,提升模型的性能。

在这里插入图片描述
下面展示一个relative_position_bias_table的例子
在这里插入图片描述

relative_position_index :相对位置编码表的索引表,即存储的值,用来取得相对位置偏移量表relative_position_bias_table 中某个位置的值,relative_position_index中存的值所取范围为 [0,168],即relative_position_bias_table 的大小为 169(13 x 13)个单元格。通过下面的图片,可以看到 relative_position_index中0 和 168 位置的编码只取一次,其实符合传统transformer中对于位置编码的运用,即开头和结尾的位置编码只用一次。

在这里插入图片描述

关注到最后计算出的 attn 需要加上位置偏移量,则这里需要看一下 relative_position_bias的计算策略,即下面的图示

在这里插入图片描述relative_position_bias的计算策略:
在这里插入图片描述
最终的 relative_position_bias, 即经过转置后和 attn 的后三维一致,进而就可以进行直接位置相加了。
在这里插入图片描述
在这里插入图片描述

SW-MSA (Shifted Window based multi-head self attention (SW-MSA) module )

SW-MSA的代码中关键步骤为 attn_mask 和 shift windows的操作,即通过对特征图移位,并给Attention设置mask来间接实现的,在保持原有的window个数下,节省计算。

首先来看 attn_mask

代码如下:

#奇数层没有shift_size 偶数层有 shift_size
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution#(56/(2^layer_index),56/(2^layer_index))
            #zero_init:img_mask (1,H,W,1)
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            #h_slices :(slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))
            #>>> c=range(0, 10)
            #>>> c[h_slices[0]]
            #   range(0, 3)
            #>>> c[h_slices[1]]
            #   range(3, 7)
            #>>> c[h_slices[2]]
            #   range(7, 10)
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            #将一个H*W的输入按照切片分为9块
            #按照H维进行切片
            for h in h_slices:
                #按照W维进行切片
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

           #将img_mask shape 1,H,W,1-> nW, window_size, window_size, 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            #nW, window_size, window_size
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            #attn_mask:[nW, window_size * window_size, window_size * window_size]
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            #矩阵中为0的置0 不为0的置 -100
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

关于程序中的数组切片,slice 部分,代码注释中有说明,注意 这里 window_size=7, shift_size=3,就不详细说明了,这里先针对 img_mask 来说明下,即下图的步骤,具体完成了哪些内容?在例子中,我让img_mask的 H=W=14
在这里插入图片描述
首先 img_mask 为和输入大小一致的张量,经过上面的slice代码的切片后,则形成了下面形状(1,14,14,1)的张量
在这里插入图片描述

注意到
如果直接将img_mask转为(14,14)的张量,我们可以看到其形状,相当于将张量根据slice切片,分为了9个部分,其中红色部分,不论img|_mask的 H W 为多少,始终为矩形,且大小为 (H-window_size)*(W-window_size),其余黄色的框基本大小是确定的,和 window_size 和 shift_size有关系。

在这里插入图片描述
其中attn_mask代码中的 mask_windows - > attn_mask的变换是关键的一步,这一步主要是让以 7*7 为单元的窗口中,块索引值相同的位置,置0,不同的位置 置为 -100 即直接屏蔽掉。

在这里插入图片描述

我们可以用以下代码模拟下,比如,a和b为shape为[2,3]的张量,则可以发现,a.unsqueeze(1) shape 为 [2,1,3], b.unsequeeze(2).shape 为 [2,3,1] ,最后经过 c = a.unsqueeze(1) - b.unsequeeze(2),c 变为了 shape [2,3,3],可以根据图中的计算过程,发现其实是 a [2,1,3] 中 b [2,3,1],就是 a 第一个 [1,3] 和 b 中 第一个[3,1] 中的每个元素进行减法操作,形成一个[3,3]的矩阵,然后再让 a 第二个 [1,3] 和 b 中 第二个[3,1] 中的每个元素进行减法操作,形成二个[3,3]的矩阵,最终形成了 [2,3,3] 的矩阵。
即经过上面的 mask_windows - > attn_mask 的运算,可以将不同窗口中,对应位置的索引相同值置0,不同值为两者的差值。

在这里插入图片描述
然后根据下面的代码进行同则赋 0,异则赋 -100。

#矩阵中为0的置0 不为0的置 -100
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

最后将得到的 attn_mask 与 得到的特征图 attn 进行相加

 #mask 在某阶段的奇数层为None 偶数层才存在
        if mask is not None:
            #nW=B*H/7*W/7 
            #mask.shape:[B*H/7*W/7 , 49, 49]
            nW = mask.shape[0]
            #mask:torch.Size([1, 4, 1, 49, 49])
            #attn.view:[B_ // nW, nW(4), self.num_heads, N, N]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)

我们用程序模拟下面这个步骤,即假设 attn.view:[B_ // nW, nW(4), self.num_heads, N, N]=[1,2,2,2,2],而mask.unsqueeze(1).unsqueeze(0).shape=[1,2,1,2,2] , 如下面的图
在这里插入图片描述
在这里插入图片描述
所以根据代码 展开 attn_mask的计算过程,可以用图示表示:在这里插入图片描述
通过图示可以发现,相当于强行将某些模块的样本用来计算对应mask的注意力值,这个属于对网络的一种约束了。且是强行分了 B_/nw 个模块,每个模块中交替进行计算对应那几个(nw)个mask的注意力。

说完了 attn_mask,再来看看 shift windows的操作,具体来讲,应该是一个特征图循环移位的操作,不过只移动了一次,所以直接用 shift 也可以理解。相关代码如下:

进行移动窗口

x = x.view(B, H, W, C)
# cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

由于代码中,默认 fused_window_process 为 False,所以进行移动窗口主要代码是:

#这里的 x = x.view(B, H, W, C)
if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

为与上面的例子对应,这里我们假设 shift_size = 3,由于 X的 shape为 [B,H,W,C] 所以可以看出,这个移位是在 H 和 W的维度分别移动 3
在这里插入图片描述
最后的shift恢复过程,就是上面 roll 和 partition 的反过程,代码中:

# reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
最后再来看一下 SwinTransformerBlock 的前向传播代码,即如果 shift_size>0 ,整体过程是对输入的整个特征图进行 循环移位 - > 然后进行带mask的注意力机制计算(SW-MSA)->再进行一系列后操作,这里并看不到针对某个窗口进行特征图移位和针对某个窗口进行 mask 均是针对整张特征图进行的相关操作。
def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

所以看了代码,不禁怀疑网上有些关于swintransformer的教程,其实有些问题的,具体还是要看代码,如果有问题,欢迎留言指正!但为什么swintransformer仅仅进行了特征图循环移位和限制性的mask注意力机制,就有效果,其实还需要深究,个人感觉是多阶段连续后,其实特征图的循环移位,让越靠后的层,越考虑了全局的特征,这个还需要再看看代码了

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值