6.一脚踹进ViT——Swin Transformer(下)
一、Shift-Window Attention
image输入后提一个patch块,如果是4×4=16个pixel,就把它展开,来做conv/proj,经过处理后它就变成1×embed dim的tensor, 每一个就叫做embedding或vision tokens,真正输入Transformer的就是embedding或vision tokens,我们在此基础上需要进行切分window
上一节我们就是这样的进行window的划分,我们上图划了四个窗口,每一窗口里面一个小的格子就是embedding或vision tokens,Swin Transformer的第一步就是对单个窗口来做self-attention,但有一个问题,Window之间没有交互,那我希望每个token能看到gloab的信息应该怎么做呢?
最简单的办法可以类似卷积一样,画一个window,在某一个window大小上做一个滑窗操作,在该window中做self-attention
Swin 论文提出了 Shift-Window Attention,之前是每个颜色单独算自己的attention,与其他颜色无关,为了改变这种情况,Swin在算完一次Window的Self-attention之后,更换了划分window的方式,这样划分完之后,窗口内的注意力就有了其他颜色的token
所以我们的Task:对每个Window单独计算WMSA,可以单独做,但我们要找一个高效的方式来做
首先看一下位移怎么做,1<<3 ,01变成1000,二进制变为十进制就是8,那这种方式就是位移的操作,可以用numpy.roll,tensor.roll来实现,其中用到的是循环填充。
将将切过的图重新标号,第一次向右shift A的宽度,再往下shift
从下到上 E FD HB IGAC,对于每个window都是M×M大小的,对于每个窗口,进行 Q 和 K的计算得到
如果只算E的,和之前的一样,如果算F D,对于左半边F,不需要管右边D的内容,需要将它遮住,如果计算右边,也同样不需要看到左侧,将其遮住或者置为0。
所以我们需要找个 mask,把不需要的地方遮挡住,对于shift window的四方格,需要对每个window中找到对应的mask。
此时我们不需要管颜色,只需要管红框,第一个不需要设置Mask
对于每个token需要展开,这样设置mask就可以计算单个红框内的值,只把红色和红色,红色和黄色的结果保留,而红色和蓝色结果就舍去了,这样设计mask左部分F的self-attention就可计算出来
同理,右侧mask这样设计
所以这一部分的mask需要这样设计
同理对于接下来的M×M窗口如何设计呢? 上方同样只关注绿 黄
下方只关注蓝 红
合并起来,即
最后一部分每个颜色为一个小窗口,自己算自己的
四个放到一起为
这样设计完后,我们怎么进行mask呢?
需要的部分是0,不需要的地方选了-100,因为softmax中x越小,得到的值越接近于0,越不影响attention
因为我们之前为了高效计算,进行了cycic shift,但是如果我们要最后计算的时候,要恢复原来的shape,需要反向cycic shift
二、Relative Position Bias
有一个tensor,尺寸是3×3,其中有9个vision tokens,将其拉平,然后得到Q×K’后 9×9的attention Weight的矩阵,这时我们需要加一个B
那这个B干什么的呢,它是一个偏置,可学习的bias,加一个相对位置的bias,希望把相对位置信息给attention weight的信息加进去。
更详细的说,如果有9个token,那它有几个相对位置呢?
我们只看同一行用相同颜色的表示,竖直方向位置变化就是不变、+1、+2、-1、-2,那同一行位置变化就是0,那么黄色比蓝色位置多1
对于横向来看,位置偏置如图,
那我们合起来会怎么样呢?
这个就是针对于横向和纵向的偏置,它一共有25种相对位置,我们可以用更少的数量表示,用索引标记
我们有这样一个可学习的Table,最后得到相对位置的索引之后,我们查询Table将其填入 Relative Position Bias中,与Attention weight相加。
这里需要注意的是尺寸的一个问题
假如我们的window为 M×M,那我们Q×K’ 为 M²×M²,我们的Relative Position Bias Table为(2M-1)×(2M-1)。
实现的话,我们可学习的Relative Position Bias Table,创建一个parameter,shape为2倍的(window_size-1)× 2倍的(window_size-1),那我们的头可以是单独的也可以是公用的,Swin中给每个头单独来做。
设置完后还需要去注册一下,将值注册到层里面
三、实践部分:主要是实现Shift-Window-Multihead-Self-Attention
1. img_mask部分
假设我们shift已经做完了,我们现在有十字格,在算mask之前,要算image mask,建立mask.py文件
有了image mask之后,我们会将它切分然后展开,看下方,针对于第四行,我们最后4578这部分,我们的目标就是这个,最终就能生成左下角的Attention Mask
首先就是生成 img_mask的代码
我们使用 window_size大小,然后slice进行划分,第一个h 从(0, -window_size),w从(0, -window_size),就把第一个window切出来,进而切下一个,w变为(-window_size,-shift_size),就把1 切出来了,经过这样操作之后,就把图中彩色部分变为黑白部分。
接下来就要做window_partition,按照十字刀来切,要对每一个四方格块来算,而每个块中的不统一,我们要经过mask来得到自己想要的部分,但第一步还是进行切分
切分后,给他把每一个window 拉平,把[1,H,W,1]的最后一维去掉,然后把H和W拉平
attn_mask = windows_mask.unsqueeze(1) - windows_mask.unsqueeze(2)
# [n,1, ws*ws] - [n, ws*ws, 1]
此时我们进行相减,在做什么呢?在Numpy中如果有一个4行一列的向量和一行三列的向量做加法,最终得到的是四行三列的矩阵,他会对于维度进行扩充,这就叫broadcasting
那我们相减操作在干什么呢?我们将展开的4578,扩充为列向量和行向量之后,进行相减,得到的是0的部分,我们可以将不是0的部分设置为255,图中是蓝色区域这就是我们要的atten_mask,将其返回即可
代码如下
# TODO: generate attn mask
def generate_mask(window_size=4, shift_size=2, input_resolution=(8,8)):
H,W = input_resolution
img_mask = torch.zeros([1, H, W, 1])
h_slices = [slice(0, -window_size), #a[slice(..)] = a[0: -window_size]
slice(-window_size,-shift_size),
slice(-shift_size,None)]
w_slices = [slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None)]
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w,:] = cnt
cnt += 1
windows_mask = windows_partition(img_mask, window_size= window_size)
# 考虑尺寸需要reshape,拉平
windows_mask = windows_mask.reshape([-1, window_size*windows_mask])
attn_mask = windows_mask.unsqueeze(1) - windows_mask.unsqueeze(2) # [n,1, ws*ws] - [n, ws*ws, 1]
attn_mask = torch.where(attn_mask!=0,
torch.ones_like(attn_mask)*255,
torch.zeros_like(attn_mask)) #如果是0,不管,不是0,设置为255
return attn_mask
在main函数里调用,来输出展示
def main():
# TODO: main
mask = generate_mask()
print(mask.shape)
mask = mask.cpu().numpy().astype('uint8')
for i in range(4): #4个子window
for j in range(16): # 16*16的尺寸
for k in range(16):
print(mask[i,j,k], end='\t')
print()
im = Image.fromarray(mask[i, :, :])
im.save(f'{i}.png')
print()
print()
print()
2. 实现shift window self window
从上一节课的main 出发,新建一个Swin_Transformer_add_shift,加入swin_block_sw_msa,对于SwinBlock进行一个修改,它做的其实就是图中的流程
这里做了两次变换,第一次是 循环移位,第二次是循环移位后,切window,展开token做注意力,做完之后将其reverse,再将窗口复原回去。
SwinBlock中 forward代码如下
class SwinBlock(nn.Module):
def forward(self,x):
H, W = self.reolution
B, N, C =x.shape
h = x
s = self.attn_norm(x)
#切 window
x = x.reshape([B, H, W, C])
##### Begin
# TODO: shift window
if self.shift_size > 0:
# 先向右再往下挪
shifted_x = torch.roll(x, shifts=(-self.shift_size,-self.shift_size), axis=(1,2))
else:
shifted_x = x
# TODO: compute window attn
# 切方块
x_windows = windows_partition(shifted_x, self.window_size)
# 将其展开序列
x_windows = x_windows.reshape([-1, self.window_size*self.window_size,C])
attn_windows = self.attn(x_windows, mask=self.attn_mask)
attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size,C])
shifted_x = windows_reverse(attn_windows, self.window_size)
# TODO: shift back
if self.shift_size > 0:
# 先向右再往下挪
x = torch.roll(shifted_x, shifts=(self.shift_size,self.shift_size), axis=(1,2))
else:
x = shifted_x
##### End
x_windows = windows_partition(x, self.window_size)
# [B * num_patches, ws, ws, c]
x_windows = x_windows.reshape([-1,self.window_size*self.window_size, C])
attn_windows = self.attn(x_windows)
# 做完attention 将它复原
attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])
x = windows_reverse(attn_windows, self.window_size, H, W)
# [B, H ,W ,C]
# 但是做mlp中 输入不是它
x = x.reshape([B, H*W, C])
x = h + x
h = x
x = self.mlp_norm(x)
x = self.mlp(x)
x = h + x
return x
这里我们调用了attn_mask,所以在 init中初始化,代码如下:
def __init__(self, dim, input_reslution, num_heads, window_size,shift_size):
super().__init__()
self.dim = dim
self.reolution = input_reslution
self.window_size =window_size
self.shift_size = shift_size
self.attn_norm = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, window_size,num_heads)
self.mlp_norm = nn.LayerNorm(dim)
self.mlp = Mlp(dim)
# TODO: generate mask and register buffer
if self.shift_size > 0:
attn_mask = generate_mask(self.window_size, self.shift_size, self.reolution)
else:
attn_mask = None
self.register_buffer('attn_mask', attn_mask)
最后一部分,我们在WindowAttention内,我们将mask传入了,要将mask加进去,来做attention
我们要将Mask加到原图中,来算attn,其中最重要的就是attn与mask维度不一致,需要进行变换reshape,unsqueeze等等
class WindowAttention(nn.Module):
def forward(self,x,mask=None):
# x: [B, num_patches, embed_dim]
B, N, C = x.shape
print('xshape=', x.shape)
qkv = self.qkv(x).chunk(3, -1)
q, k, v = map(self.tranpose_multi_head, qkv)
q = q * self.scale
attn = torch.matmul(q, k.transpose(-1,-2))
# [B * num_windows, num_heads, num_patches, num_patches]
print('attn shape=',attn.shape)
##### Begin: Mask
# TODO: reshape and add mask if mask is not none
if mask is None:
attn = self.softmax(attn)
else:
# mask : [num_windows, num_patches, num_patches]
# attn: [B * num_windows, num_heads, num_patches, num_patches]
attn = attn.reshape([B//mask.shape[0],mask.shape[0],self.num_heads, mask.shape[1], mask.shape[1]])
# attn: [B, num_windows, num_heads, num_patches, num_patches]
# mask : [ num_windows, 1, num_patches, num_patches]
attn = attn+ mask.unsqueeze(1).unsqueeze(0)
# mask : [1, num_windows, 1, num_patches, num_patches]
attn = attn.reshape([-1,self.num_heads,mask.shape[1],mask.shape[1]])
# attn: [B * num_windows, num_heads, num_patches, num_patches]
##### End: Mask
out = torch.matmul(attn, v) # [B, num_heads, num_patches, dim_head]
out = out.permute([0, 2, 1, 3])
# # [B, num_patches, num_heads, dim_head] num_heads * dim_head= embed_dim
out = out.reshape([B, N, C])
out = self.proj(out)
return out
运行结果如下:
完整代码之后会加入资源中,目前没有实现 Relative Position Bias,其他全部实现了!