Swin Transformer
解决ViT对于下游任务不友好的问题,提出了滑动窗口
Swin的特点:
-
从小Patch开始,逐层合并相邻Patch
-
计算Window Attention
-
提出Shifted Window操作,更有效计算Attention
1.论文阅读笔记
Swin Transformer中使用移动窗口构建了层级式的ViT,让ViT像CNN一样也能够分成几个block,能做层级式的特征提取,对图像大小具有线性计算复杂度。
1.1 摘要:
指出问题Transformer从NLP用到视觉中的问题:
-
尺度太大(街景中有行人和汽车,有各种各样的尺寸,但在NLP中不存在)
-
resolution太大导致序列太长,计算量很大
之前解决办法:
-
用后续特征图作为Transformer的输入
-
将图片打成Patch,减少图片的resolution
-
将图片划成一个个的小窗口,在窗口中做自注意力。
本文提出移动窗口,不仅减少了计算量,因为有移动操作,能让相邻的两个窗口之间有了交互,于是上下层之间有了cross-window的连接,这种层级式结构的好处不仅灵活,提供各个尺度的信息,同时自注意力在小窗口内算的,计算复杂度随着图像大小而线性增长的。
1.2 结论:
使用小窗口内计算自注意力,而不同与ViT在整图上算自注意力,只要窗口大小是固定的,那SA的复杂度就是固定的,整张图的计算复杂度就随着图像大小呈线性增长关系,图像变为x倍,窗口数量就增大x倍,复杂度就是x倍,而非x的平方。
利用的CNN中局部性的归纳偏置,同一个物体的不同部位(语义相近的不同物体),还是大概率会出现在相连的地方,即使在小范围的窗口中也是够用的,全局中算注意力可能是浪费资源的。
CNN中多尺寸的特征是由池化操作,能够增大卷积核能看到的感受野,从而使得每次池化过后的特征抓住物体的不同尺寸,本文提出了patch merging,使得相邻的patch合成一个大patch,使得感受野增大,抓住多尺度特征。有了4×,8×,16×这种多尺度特征,扔给FPN就能做检测了,扔给UNET就能做分割了,所以Swin Transformer能当作骨干网络来做。
划分完后,窗口与窗口之间可以互动
假如一张图片224×224×3,首先打成4×4的patch,每个patch中图像大小变为原来1/4,即56,维度就变为4×4×3;接下来Linear Embedding把向量维度变为预先设置好的值(Swin Transformer能接受的值),超参数C=96,走完就变为56×56×96,拉直后就变为3136 ×96,而ViT中序列长度是16×16,而此时是3136非常大,本文是基于窗口来算的,窗口中仅有7×7=49个patch,暂且将它当作黑盒,在其中做了自注意力操作,如果不对其做约束,输入输出的尺寸是不变的,即输出还是56×56×96。
在Patch Merging中使用两个1×1的卷积来降维,将通道数由4C变为2C,目的是:与池化层相同,使得图像大小翻倍,通道数减半。
2.Swin Transformer架构分析
类似ViT架构,对于输入图像进行Patch Partition(图像分块)、Patch Embedding,然后经过4个stage,类似于ResNet中的Stage,在Stage中主要由 Swin Transformer Block 构成,最后进行一个Patch Merging进行融合。
最关键的操作就是Swin Transformer Block和Patch Merging
上图展示了模型的大致结构图,我们还可以关心一下数据的流动
将图中拆分来看
1.以patch输入到网络后,如果是彩色图像,它的channel就为3,经过Patch Embedding后,通道数就变为embed_dim
2.得到Patch Embedding后,再使用 窗口(Windows)再切一次patch,我们目前的输入已经是feature level的tensor,做一次Windows Partition切成不重叠的窗口
3.如果没有划分window,我们做的就是每一个patch与其他所有batch,现在划分后,现在在每一个窗口中单独去做即可,可以减少计算量,不需要算每个batch与其他batch,我们经过attention后,输出维度和输入一样,这样对每个单独的window做完后,最终维度同样尺寸大小的tensor
4.Patch Merging
swin transformer中就是对相邻的4个image token融合起来,空间上尺寸变小,同时会将embed_dim的维度扩大2倍
5.Next Stage
一个Stage部分做完后,走到了下一个Stage,此时输入是merge之后更小的输入,继续重复上述步骤,切window,减少尺寸,升高维度
在某些Stage中,重复做多次Block块,但是其中不会改变尺寸,输入输出维度一直不变
2.1 Swin Transformer Block
本部分介绍Block是如何构建的
由W-MSA(Window Multi-head Self Attention)和SW-MSA(Shifted Window Multi-head Self Attention)构成,本文先不介绍移动窗口部分,仅看左侧如何处理,数据进入后通过LN层,然后再走W-MSA,进行残差连接,再走LN,MLP,残差融合,与之前的差不多,就是需要修改W-MSA部分
W-MSA
Tensor通过划分Window操作后,只拿自己的Window出来,把其中16个token提出来,来做Attention;之后再拿window2自己的16个token做attention,每一个分开做
论文说W-MSA比MSA计算量小,推算一下公式
可以看到两个公式第二项不一样,MSA中随着h·w尺寸/(patch_num)呈平方级增长,但在W-MSA中呈线性关系,如果图像切图像尺寸越小,使用W-MSA效率会更高。
2.2 Patch Merging
将同一个window中画的颜色不同四个部分,排到一起,将每一部分的tokens再并列,原来merge后得到小window的dim会变为原来的4倍,此时token的数量会变为原来的1/4,之后在做一次,将其映射到2倍。
最后对映射后的内容reshape回去,于是长宽均变为原来1/2,维度变为2倍
3. 代码实现
涉及W-MSA 和 Patch Merging 以及 Window Partition
Window Partition是将我们的tensor切成window,然后送到attention中去算,所以有三个QKV。假设我们一个batch有3个样本,每个样本尺寸是一样的,都要把红框的window切出来,每个window单独做attention。
可以把一个batch的所有小window拼到一起去,所有的window直接是没关系的,window只管自己的,不管怎么排列,我只算自己的
我们看到的一个方块的一个小格,要它与算窗口内其他所有格的attention,这叫做window_attention,然后将每个小窗口的16个token拉出来,展开,即
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
def __init__(self,patch_size=4,embed_dim=96):
super().__init__()
self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size,stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
x = self.patch_embed(x) #[n, embed_dim, h', w']
x = x.flatten(2) #[n, embed_dim, h'w']
x = x.permute(0, 2, 1) # [n, h'*w', embed_dim]
x = self.norm(x)
return x
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim):
super().__init__()
self.resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear( 4 * dim ,2 * dim)
self.norm = nn.LayerNorm(4 * dim)
def forward(self, x):
h, w = self.resolution
b, _, c = x.shape # _ 不用,其是 num_patches,即 h*w
x = x.reshape([b, h, w, c])
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 1::2, 0::2, :]
x3 = x[:, 1::2, 1::2, :]
x = torch.concat([x0, x1, x2, x3], axis=-1) # [B, h/2, w/2, 4c]
x = x.reshape([b, -1, 4*c])
x = self.norm(x)
x = self.reduction(x)
return x
class Mlp(nn.Module):
def __init__(self, dim, mlp_ratio=4.0, dropout=0.):
super().__init__()
self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
self.fc2 = nn.Linear(int(dim * mlp_ratio),dim)
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
def windows_partition(x, window_size):
B, H, W, C = x.shape
x = x.reshape([B, H//window_size,window_size, W//window_size, window_size, C])
x = x.permute([0,1, 3, 2, 4, 5])
# [B, h//ws, w//ws, ws, ws, c]
x = x.reshape([-1, window_size, window_size, C])
# [B * num_patches, ws, ws, c]
return x
def windows_reverse(windows, window_size, H, W):
B = int(windows.shape[0]// (H/window_size * W/window_size))
x = windows.reshape([B, H//window_size, W//window_size, window_size, window_size, -1])
x = x.permute([0, 1 ,3, 2, 4, 5])
x= x.reshape([B, H, W, -1])
return x
定义了WindowAttention,我们将它组合起来
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads):
super().__init__()
self.dim = dim
self.dim_head = dim// num_heads
self.num_heads = num_heads
self.scale = self.dim_head ** -0.5
self.softmax = nn.Softmax(-1)
self.qkv = nn.Linear(dim,
dim * 3)
self.proj = nn.Linear(dim, dim)
def tranpose_multi_head(self, x):
new_shape = x.shape[:-1] + (self.num_heads, self.dim_head)
x = x.reshape(new_shape)
x = x.permute(0, 2, 1, 3) #[B, num_heads, num_patches, dim_head]
return x
def forward(self,x):
# x: [B, num_patches, embed_dim]
B, N, C = x.shape
qkv = self.qkv(x).chunk(3, -1)
q, k, v = map(self.tranpose_multi_head, qkv)
q = q * self.scale
attn = torch.matmul(q, k.transpose(-1,-2))
attn = self.softmax(attn)
out = torch.matmul(attn, v) # [B, num_heads, num_patches, dim_head]
out = out.permute([0, 2, 1, 3])
# # [B, num_patches, num_heads, dim_head] num_heads * dim_head= embed_dim
out = out.reshape([B, N, C])
out = self.proj(out)
return out
class SwinBlock(nn.Module):
def __init__(self, dim, input_reslution, num_heads, window_size):
super().__init__()
self.dim = dim
self.reolution = input_reslution
self.window_size =window_size
self.attn_norm = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, window_size,num_heads)
self.mlp_norm = nn.LayerNorm(dim)
self.mlp = Mlp(dim)
def forward(self,x):
H, W = self.reolution
B, N, C =x.shape
h = x
s = self.attn_norm(x)
#切 window
x = x.reshape([B, H, W, C])
x_windows = windows_partition(x, self.window_size)
# [B * num_patches, ws, ws, c]
x_windows = x_windows.reshape([-1,self.window_size*self.window_size, C])
attn_windows = self.attn(x_windows)
# 做完attention 将它复原
attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])
x = windows_reverse(attn_windows, self.window_size, H, W)
# [B, H ,W ,C]
# 但是做mlp中 输入不是它
x = x.reshape([B, H*W, C])
x = h + x
h = x
x = self.mlp_norm(x)
x = self.mlp(x)
x = h + x
return x
最终用一个主函数去调用
def main():
t = torch.randn([4, 3, 224, 224])
patch_embedding = PatchEmbedding(patch_size=4, embed_dim=96)
swinBlock = SwinBlock(dim=96, input_reslution=[56,56], num_heads=4, window_size=7)
patch_merging = PatchMerging(input_resolution=[56,56], dim=96)
out = patch_embedding(t) #[4, 56, 56, 96]
print('path_embedding out shape= ',out.shape)
out = swinBlock(out)
print('swinBlock out shape= ',out.shape)
out = patch_merging(out)
print('patch_merging out shape= ',out.shape)
if __name__ == '__main__':
main()
首先我们输入的是一个batch的data,[4, 3, 224, 224],batch_size为4,我们通过patch_embedding操作,取一定大小的patch,patch_size 为4,所以变换完后tensor变为[4,56,56,96],3136是为了下一步做attention的时候方便。
在swinBlock,其中是做了windows_partition、WindowAttention,其中是不变换维度大小的
最后做了patch_merging,类似池化,相邻的4个token合并,维度扩大了2倍,即96变为192,而784是28×28,即56×56缩小了两倍
所以WindowAttention主要也是reshape再变回去