0、前言
最近几天看了Swin-Transformer这篇论文,在看代码时对其中的掩码机制不解,尤其是看不懂代码的理解,而Swin的掩码机制又是论文的亮点之一,在查阅各方资料后终于弄懂了原理。
1、什么是掩码机制?
1.1滑动窗口机制(Shift windows)
为了理解什么是掩码机制,我们需要知道为什么需要掩码机制,这就是因为Swin的滑动窗口的原因。
图1.1.1
众所周知,Swin的自注意力机制是基于窗口的自注意力机制,如图1.1.1所示,而基于窗口的自注意力机制意味着窗口和窗口之间的联系消失了,这其实有悖于Transformer结构的全局自注意力机制,也丢掉了Transformer的最大优势。因此Swin的作者在这里又提出了一种自注意力机制。
图1.1.2
Swin中的另一种自注意力机制就是基于滑动窗口的自注意力机制,如图1.1.2所示。
图1.1.3:左位未经滑动的原图,右图经过滑动的结果图
图1.1.3左图指代未经过滑动,图片一共有四个窗口,分别在这四个窗口内做自注意力机制。图1.1.3的右图指代经过滑动窗口后,还是四个窗口内做自注意力,但这四个窗口已经与图左的四个窗口有所区别。从1.1.3可以看出,滑动指的就是将图片的左面一部分和上面一部分分别移动到图片的右面和下面,移动的大小即为Shift-Size。
图1.1.4
图1.1.4又将四个窗口分为了9个小窗口(这里我们为与图1.1.3区分,采用(数字)的形式来表示),窗口(2)即是从左边的B移过来的。这里为什么要分为9个小窗口呢,我们来一一说明。
- 首先我们讨论1.1.3的窗口1(1.1.3右图),它在图1.1.4中被分为窗口(0),这个窗口在滑动后可以直接做自注意力。
- 接着我们讨论1.1.3的窗口2(1.1.3右图),它在图1.1.4中被分为窗口(1)和(2),(1)和(2)是不可以做自注意力的,因为在原图中,(2)来自于左边的B,(2)于(1)在原图中是不相邻的,所以不应该做自注意力。
- 1.1.3的窗口3被分为(3)和(6),它两也是不该做自注意力的。
- 1.1.3的窗口4被分为(4),(5),(7),(8),它们之间不该做自注意力机制。
基于上面的问题,作者希望提出一种掩码机制,使得滑动后该做自注意力的做,不该做自注意力的不做。
2、掩码机制的具体实现
2.1作者解释
实际上有非常多的人对掩码机制不解,因此GIt上就有人就这个问题向作者提出了疑问,作者也是给出了非常精彩的回答,这里我们贴出地址:https://github.com/microsoft/Swin-Transformer/issues/38
2.2代码解读
作者贴出的回答其实已经够详细了,但是为了以防有些同学还是不够理解,我们做一下代码解读,代码为作者在回答中贴出的代码,我将其粘贴在这里:
import torch
import matplotlib.pyplot as plt
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
window_size = 7
shift_size = 3
H, W = 14, 14
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
plt.matshow(img_mask[0, :, :, 0].numpy())
plt.matshow(attn_mask[0].numpy())
plt.matshow(attn_mask[1].numpy())
plt.matshow(attn_mask[2].numpy())
plt.matshow(attn_mask[3].numpy())
plt.show()
同时我将结果也贴出来:
图2.2.1
这里我们直接看代码,作者首先定义了几个变量,窗口大小、滑动步长、图片大小。之后,作者生成了一个(1,h,w,1)的img_mask,这里维度对应(B,h,w,C),即Batch_Size,图片大小,C个通道,这里是为了和正式项目中的维度对应,可以理解成生成一个跟图片大小相同的img_mask。
接下来作者定义了h_slices和w_slices两个切片,就是在h和w两个方向上分别将图片分为三块,合起来就是九块,这里是用两个for循环实现,对不同的部分赋予不同的值,最后的结果如下所示:
图2.2.2
图2.2.2中的0-8既指代9个小窗口,同时也表示每一个窗口内赋予的值。Window0-3指大窗口。这一点我们在前面有所提及。
图2.2.3
这里为了方便显示,我们去掉两个维度显示img_mask(即图2.2.3红色方框所做的操作),如图2.2.3所示,可以看到 img_mask被分为9个部分同时每个部分被赋予不同的值。
接着img_mask在通过window_partition方法后,维度为[4,7,7,1],即是14*14的图片被分为四个长宽各为7的小窗口,接着维度进一步转换为[4,49],这里都是常规化的操作。
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
最关键的代码在这里,这里mask_windows的维度为[4,49],代表一共四个窗口,每个窗口内有7*7个值。unsqueeze代表增加维度,这里其实使用了pytorch的广播机,mask_windows.unsqueeze(1)变为[4,1,49],mask_windows.unsqueeze(1)变为[4,49,1],两者相减变为[4,49,49]。
这里我们定义俩个临时变量来表示:
temp1=mask_windows.unsqueeze(2)
temp2=mask_windows.unsqueeze(1)
temp1_1=np.repeat(temp1,49,axis=2)
temp2_2=np.repeat(temp2,49,axis=1)
temp1的维度为[4,49,1],temp2的维度为[4,1,49],两者相减相当于temp1在维度2上复制49次变为[4,49,49],temp2为维度1上复制49次,变为[4,49,49],这样就可以进行相减了。
这里的维度4代表四个窗口,我们以第二个窗口为例,取temp1[1]输出即为第二个窗口并展开:[1,1,1,1,2,2,2,.......,1,1,1,1,2,2,2],对应下图红色框展开:
图2.2.4:图2.2.3的窗口1
将temp1在维度2上复制49次,得到维度为temp1_1[4,49,49],我们这里输出temp1_1的第二个窗口即temp1_1[1]:
图2.2.5:temp1_1[1]图片显示不全,应该为[49,49],剩下自行脑补
这里我们只截取了部分,可以看到相当于一列复制了49次,得到的[49,49]的每一行内的值都是相等的,且每一行的值分别对应窗口1的全部49个值。
同样我们展开temp2_2的第二个窗口temp2_2[1]:
图2.2.6:temp2_2[1]
temp1_1相当于一个窗口内的49 个值在列的方向上复制49次,temp2_2相当于一个窗口内的49个值在行的方向上复制49次,因此这49行是相等的,每一行的值都相当于窗口1内的全部49个值。
经过这样操作后,temp1_1的[49,49]中的第[1,49]相当于窗口1的第一个像素的值复制了49次,[2,49]相当于窗口1的第二个像素的值复制了49次,以此类推。temp2_2的[49,49]的[1,49]相当于窗口1内的所有49个值,[2,49]也是这49个值。(这里不好理解,大家对照图2.2.5和2.2.6理解)
temp1_1-temp2_2相当于窗口内的每一个值都要与所有的值相减,我们记得前面我们给9个小窗口赋予了不同的九个值,所以这里相减的话,不在一个小窗口内的像素点一相减,就会变为非0,而在一个小窗口内的一相减,就会变为0。
其实到这里我们就实现了掩码机制,将非0值设为-100(一个非常小的负值),即认为不会注意。
3、总结
到这里我们已经实现了对Swin掩码机制的解释,有些讲的不好的地方希望大家理解。