Swin Transformer 是2021.3在ICCV发表的一篇论文,同时也是这一年ICCV的best paper,在各大检测、分割任务中有着非常出色的结果。
论文地址:https://arxiv.org/pdf/2103.14030.pdf
论文官方代码:https://github.com/microsoft/Swin-Transformer
目录
整体架构
整个模型采取层次化的设计,一共包含4个Stage,与ViT不同,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野,这样解决了多尺寸目标检测的问题,为本文的一大创新。
这里放一个Swin 和ViT的对比图,这其实可以看成作者对ViT的一个改进,ViT从头至尾都是对全局做self-attention,而swin-transformer是一个窗口在放大的过程,然后self-attention的计算是以窗口为单位去计算的,这样相当于引入了局部聚合的信息,和CNN的卷积过程很相似,就像是CNN的步长和卷积核大小一样,这样就做到了窗口的不重合,区别在于CNN在每个窗口做的是卷积的计算,每个窗口最后得到一个值,这个值代表着这个窗口的特征。而swin transformer在每个窗口做的是self-attention的计算,得到的是一个更新过的窗口,然后通过patch merging的操作,把窗口做了个合并,再继续对这个合并后的窗口做self-attention的计算。
Swin Transformer Block
可以注意到这4个stage的Swin Transformer Block都有×2或×6,是2的倍数,因为这个是两个Successive block组合的,两个block的不同在于一个是W-MSA(window multi-head self attention),另一个是SW-MSA(shifted window multi-head self attention),后面会分别详细讲解。
Patch Partition
首先,输入图像H×W×3,输入到Patch Partition模块,在代码中是PatchEmbed类实现的,我们来看一下PatchEmbed的forward()函数:
def forward(self, x):
"""Forward function."""
# padding 确保 H、W为patch_size的整数倍
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C H/4 W/4 C=embed_dim
if self.norm is not None: # 下面是一个normalization的一个操作
Wh, Ww = x.size(2), x.size(3) # B C H/4 W/4
x = x.flatten(2).transpose(1, 2) # (B, H/4×W/4, embed_dim)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
# (B, embed_dim, H/4 W/4)
return x # (B, embed_dim, H/4, W/4)
可以看到PatchEmbed的核心代码就是self.proj()函数,如下:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = self.proj(x)
而self.proj()是一个卷积函数,输出的通道数为embed_dim,卷积核的大小就是patch_size,这样下来,就是把输入图像通过一个4×4的卷积,H、W分别缩小了4倍。
Linear Embedding
在输入到Transformer网络之前,要把维度变为(B,n_tokens, embed_dim)的格式,才能进行接下来的multi-head self attention,对经典的transformer网络不熟悉的同学可以去把transformer看懂了再过来,这里就不讲了。其实这个代码很简单,输入维度是(B, embed_dim, H/4, W/4),这样我们的token的数目就是H/4 × W/4 = HW/16,而每一个token的维度就是我们初始化网络输入的embed_dim,这里有不同尺寸的Swin网络(Swin-T,Swin-S,Swin-B,Swin-L),论文里面参数设置如下:
- Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
- Swin-S: C = 96, layer numbers ={2, 2, 18, 2}
- Swin-B: C = 128, layer numbers ={2, 2, 18, 2}
- Swin-L: C = 192, layer numbers ={2, 2, 18, 2}
这里的C就是 embed_dim,后面的layer numbers对应每个stage的Swin Transformer的数目。所以只需要一行代码就能实现维度转换(B,n_tokens, embed_dim)的操作:
x = x.flatten(2).transpose(1, 2)
当然, Linear Embedding还有一个位置编码的操作,论文这里用了absolute position embedding。代码里面绝对位置编码就是随机初始化参数,与x的Batch维度以外的相匹配,作为可学习的位置编码参数,如下:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
接着,x与位置编码相加,再进行flatten和transpose的操作,如下:
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # (B, H/4*W/4, embed_dim)
Transformer Block
Linear Embedding之后就输入到Swin Transformer Block,首先输入的是Windows的多头注意力机制Transformer Block;接着输入的是shifted Windows的多头注意力机制Transformer Block。
Window based Self-Attention
首先介绍Window based Self-Attention,这个Attention简单来说就是CNN版的ViT,把特征图分为7*7(假设设置的窗口大小为7*7)大小的窗口,对每一个窗口进行Linear Embedding的操作,然后像ViT一样,输入多头注意力机制网络,如图:
window_partition代码如下:
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows (windows_num, window_size, window_size, embedding_dim)
然后把把每个小window的token展平:
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # (nW*B, window_size*window_size, embedding_dim)
接下来输入到W-MSA网络,和普通的Transformer网络一样,由(B, token_num, embedding_dim) 得到qkv,然后注意力矩阵attn=q · k.transpose,输出x=attn·v;值得注意的是,我们这里B=windows_num * Batchsize, token_num=window_size*window_size。另外更值得一提的是,这里的注意力矩阵attn是加上了相对位置编码的,后续论文的实验有证明相对位置编码提升了模型性能。WindowAttention 的forward函数:
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) #(B*windows_num, head_num, tokens_num, tokens_num)
# tokens_num = windows_size * windows_size
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
相对位置编码
这里面比较难理解的是相对位置编码是怎么定义的,我们可以在上面看到相对位置编码是与注意力矩阵attn相加的,这里atten的维度为(B*windows_num, head_num, tokens_num, tokens_num),我们要得到与此维度相同的相对位置编码矩阵;
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
很明显,这里用到了初始化window类的两个self变量--self.relative_position_bias_table和self.relative_position_index,先介绍前面relative_position_bias_table,
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
trunc_normal_(self.relative_position_bias_table, std=.02)
它是通过nn.Parameter初始化,然后变为方差std为0.02的正态分布数据,它是一个二维的tensor数组,shape[0]为(2*window_h-1)*(2*window_w-1),假设windows的高宽都为7,那么shape[0]=13*13=169;shape[1]为多头注意力head的数目。
重点是self.relative_position_index是怎么定义的,看名字知道它是一个位置索引数组。首先,在一个二维的Windows中,每一个token的位置是一个二维的坐标,但是进入transformer网络后,二维的token要展平成一维的,这就意味着需要把二维的相对位置距离转为一维的,以下是相对位置编码index的代码:
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
我们这里假设windows_h=windows_w=3,一步一步来看:
window_size = [3, 3]
coords_h = torch.arange(window_size[0]) # 0,1,2
coords_w = torch.arange(window_size[1]) # 0,1,2
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
## coords:
tensor([[[0, 0, 0],
[1, 1, 1],
[2, 2, 2]],
[[0, 1, 2],
[0, 1, 2],
[0, 1, 2]]])
上面是横坐标,下面是纵坐标,合起来看的话就是这样:
我们给每一个位置编号如左图
接下来flatten的操作:
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Wh
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
[0, 1, 2, 0, 1, 2, 0, 1, 2]])
上面是横坐标,下面是纵坐标
接下来,利用广播机制,分别在第一维,第二维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww
的张量
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
如上,我们可以看到这个两个tensor,右边的就是左边对应横坐标已经纵坐标的转置,两个tensor相减之后,就得到了每一个位置的分别相对于其他位置的距离,个人看这一方法有类似与求自相关矩阵。我们把结果打印出来看看:
tensor([[[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 0, 0, 0, -1, -1, -1, -2, -2, -2],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 1, 1, 1, 0, 0, 0, -1, -1, -1],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0],
[ 2, 2, 2, 1, 1, 1, 0, 0, 0]], 横坐标相对偏移
[[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0],
[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0],
[ 0, -1, -2, 0, -1, -2, 0, -1, -2],
[ 1, 0, -1, 1, 0, -1, 1, 0, -1],
[ 2, 1, 0, 2, 1, 0, 2, 1, 0]]]) 纵坐标相对偏移
我们可以看到对角线都是0,这一点很好解释,3*3窗口中的9个token对自己的偏移都是0。好,我们把横纵坐标合在一起,如下:
tensor([[[ 0, 0],[ 0, -1],[ 0, -2],[-1, 0],[-1, -1],[-1, -2],[-2, 0],[-2, -1],[-2, -2]],
[[ 0, 1],[ 0, 0],[ 0, -1],[-1, 1],[-1, 0],[-1, -1],[-2, 1],[-2, 0],[-2, -1]],
[[ 0, 2],[ 0, 1],[ 0, 0],[-1, 2],[-1, 1],[-1, 0],[-2, 2],[-2, 1],[-2, 0]],
[[ 1, 0],[ 1, -1],[ 1, -2],[ 0, 0],[ 0, -1],[ 0, -2],[-1, 0],[-1, -1],[-1, -2]],
[[ 1, 1],[ 1, 0],[ 1, -1],[ 0, 1],[ 0, 0],[ 0, -1],[-1, 1],[-1, 0],[-1, -1]],
[[ 1, 2],[ 1, 1],[ 1, 0],[ 0, 2],[ 0, 1],[ 0, 0],[-1, 2],[-1, 1],[-1, 0]],
[[ 2, 0],[ 2, -1],[ 2, -2],[ 1, 0],[ 1, -1],[ 1, -2],[ 0, 0],[ 0, -1],[ 0, -2]],
[[ 2, 1],[ 2, 0],[ 2, -1],[ 1, 1],[ 1, 0],[ 1, -1],[ 0, 1],[ 0, 0],[ 0, -1]],
[[ 2, 2],[ 2, 1],[ 2, 0],[ 1, 2],[ 1, 1],[ 1, 0],[ 0, 2],[ 0, 1],[ 0, 0]]])
解释一下第一行是什么意思,[0,0]是指窗口里面的第一个token相对于自己,x轴的偏移与y轴的偏移都是0;[0,-1]是指window里的第2个token相对于第一个token,x轴的偏移是0,y轴的偏移是-1,以此类推···
那么第二行,[0,1]是指window里的第1个token相对于第2个token,x轴的偏移是0,y轴的偏移是1;[0,-1]是指window里的第3个token相对于第2个token,x轴的偏移是0,y轴的偏移是-1,以此类推···
这样,我们把x轴和y轴的偏移加起来,就是相对位置的总的偏移,当然,这里面x轴与y轴相加,会有负数的情况,代码里面为了不让出现负数,即相对位置偏移从0开始,进行了如下操作:
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
即,x与y都加上(3-1)=2
当然,代码还有一个操作,就是对所有的x偏移都乘上了(2 * window_size[1] - 1)这个数,这里window_size[1]=3时,就是5;
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
最后对x轴与y轴求和,就是最后相对位置偏移的index:
relative_position_index = relative_coords.sum(-1)
tensor([[12, 11, 10, 7, 6, 5, 2, 1, 0],
[13, 12, 11, 8, 7, 6, 3, 2, 1],
[14, 13, 12, 9, 8, 7, 4, 3, 2],
[17, 16, 15, 12, 11, 10, 7, 6, 5],
[18, 17, 16, 13, 12, 11, 8, 7, 6],
[19, 18, 17, 14, 13, 12, 9, 8, 7],
[22, 21, 20, 17, 16, 15, 12, 11, 10],
[23, 22, 21, 18, 17, 16, 13, 12, 11],
[24, 23, 22, 19, 18, 17, 14, 13, 12]])
拿对角线的12算,就是12 = (0+2)*5+(0+2)=12,其他位置的偏移也是这么算的。
前面我们得到了一个(2*window_h-1)*(2*window_w-1) ,head_num即(25,3)的relative_position_bias_table,上面的最大index就是24,即数组的第25个,正好对应上了。所以经过如下操作,就得到了相对位置编码:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
至于相对位置编码为什么要这样排列,一个Window里面只有9个patch,为什么要弄出这么一个9*9的位置矩阵?其实就是要跟attn矩阵对应起来,我们的attention矩阵就是由9个初始的patch分别得到9个q、k、v,然后q和k做相关矩阵,从另外一个角度讲就是自己跟自己做自相关,这一点不明白的可以先把transformer看明白,比如自相关矩阵第一行,就是每个patch分别与第一个patch的相关度,而第二行就是每个patch与第二个patch的相关度;所以这里的位置编码要这么排列,才能和attn矩阵的位置相对应,这样才能相加。
Shifted Window Attention
Shifted Window Attention 算是论文的特色创新点之一了,上面解释的Window Attention是在每个窗口下计算注意力的,为了更好的和其它窗口进行信息交互,Swin Transformer还引入了shifted window操作。
以论文中的图为例子,左图有4个window,每一个window有4*4个patch(就是token),左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。这里的shift的长度在代码中是这样的
shift_size=0 if (i % 2 == 0) else window_size // 2,
它的意思是偶数序号block 的shift_size=0,即进行普通的Window Attention;基数序号block 的shift_size=window_size//2,即窗口尺寸的一半(严谨点,window_size是偶数情况);但这也引入了一个新问题,即window的个数翻倍了,由原本四个窗口变成了9个窗口。在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。
特征图位移操作
如上图,论文图解是把右下角的A、B、C移到了左上,在代码里面是通过roll函数实现的:
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # # (B, H/4, W/4, embed_dim)
attn_mask = mask_matrix
图3和图1我们可以看到,它们有相同的数目的window,但是每一个window的内容却不一样。
Attention Mask
其实到这里,Shifted Window Attention已经很明朗了,把上面经过roll的图(上面图3)和W-MSA一样,做相同的操作就行了;但是,论文中有一个masked MAS,那么这个mask是什么功能。
我们先画一个经过roll之后的图,画一个大一点的方便理解:
如图8*8个大小为4*4的window,经过roll之后,在边界会出现这样的windows:
在边界的Window会同时具有来自于原图像(橙色、浅蓝色、褐色)的像素。但是由于这些window的不同颜色的像素,在源图像中两者本来就距离较远,所以在提取局部特征的时候,论文认为二者并不适合被认为互为邻域,不适合放在一起计算attention,为了结局这一问题,论文提出了mask 的Attention,我们来看看代码是这么做的。
拿2*2个大小为2*2的window的图像为例子,我们先对shift之后的区域编一个号:
先拿连续区域的1号区域为例子,看看attn矩阵是这么算的;首先我们把里面的4个patch展开,经过全连接得到q、k、v的3个矩阵,q 与k的转置进行矩阵相乘之后,得到atten矩阵:
然后第二个window的,将patch展开,q 与k的转置进行矩阵相乘之后,得到atten矩阵:
同理第3个window和第4个window的结果,如下:
上图中有颜色的区域是连续区域做相乘的,二没有颜色的区域,是在原图中相隔很远的区域相乘的结果,论文希望最后attn 与 v 相乘后,能够把没有颜色的区域忽略掉,只算相邻像素的自相关,于是代码设置了和attn同意大小的mask,有颜色区域设置为0,没有颜色区域设置为-100,在attention计算过程中,将得到的attn矩阵与mask相加,由于对该矩阵加了softmax所以相互对应位置均为0,也就是说两个不同的部分相互之间不会参与attention的计算,达成了隔离计算的目的。
tensor([[[[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]],
[[[ 0., -100., 0., -100.],
[-100., 0., -100., 0.],
[ 0., -100., 0., -100.],
[-100., 0., -100., 0.]]],
[[[ 0., 0., -100., -100.],
[ 0., 0., -100., -100.],
[-100., -100., 0., 0.],
[-100., -100., 0., 0.]]],
[[[ 0., -100., -100., -100.],
[-100., 0., -100., -100.],
[-100., -100., 0., -100.],
[-100., -100., -100., 0.]]]]])
mask的代码如下:
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size # np.ceil(a) ,np.floor(a) : 计算各元素的ceiling 值, floor值(ceiling向上取整,floor向下取整)
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 (1,182,322,1)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
在代码里,第二个block的attn矩阵与mask相加如下:
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
至此,当第二个shift的W-MSA结束后,会有一个reverse shift的操作,就是把之前roll的像素,再roll还原;
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
PatchMerging
上面经过一个stage后,会有一个PatchMerging 的操作,其实就是downsample,将特征图缩小到原来的一半;废话不说,附上PatchMerging的forward代码:
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
画图形象一点,如下:
然后把得到的4张图cat到一起,就完成了下采样的操作,当然,得到的图的channel增加了4倍,后面经过view成一位的向量之后,通过全连接层,将维度缩小两倍,代码如下:
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
接下来就进入下一个stage,各部分功能代码是一样的,重复就是。
实验结果
论文代码的各项检查以及分割的指标非常好, 在ImageNet-1K达到了84.2%,在ImageNet-2K精度更是达到了86.4%,在分割任务ADE20K的mIoU达到了53.5%。
总结
这篇论文创新点很棒,引入window这一个概念,将CNN的局部性引入,还能控制模型整体计算量,通过transformer得到了多尺度大小的特征图,方便了后续的分类以及分割的任务。相对位置编码部分,代码设计的也很是巧妙;在Shift Window Attention部分,用一个mask来进行非连续像素块的Attention矩阵的筛选,很是巧妙,论文非常推荐阅读,另外,文章制作不易,欢迎点赞收藏!