Swin Transformer Block
This section is the core of the whole implementation and covers many of the ideas in the paper: relative position encoding, masking, window self-attention, and shifted window self-attention. It is also where I spent the most time while reading the paper and the source code.
The figure above shows the Swin Transformer Block structure from the paper. Swin Transformer uses window self-attention to reduce computational complexity. To keep the non-overlapping windows connected to one another, it recomputes self-attention once more after shifting the windows (shifted window self-attention). Swin Transformer Blocks therefore always come in pairs (one W-MSA plus one SW-MSA), so every Swin variant has an even number of Blocks per stage; an odd count is impossible.
First, the relative position encoding:
Attention(Q, K, V) = SoftMax(QKᵀ/√d + B)V
Absolute position encoding adds a learnable parameter to every token before the self-attention computation. Relative position encoding, shown above, instead adds a learnable relative position bias B inside the self-attention computation itself.
Suppose window_size = 2×2, i.e. each window holds 4 tokens (M = 2), as shown in Figure 1. When computing self-attention, every token computes a QK value with every token in the window, as shown in Figure 6. When the token at position 1 computes its self-attention, it computes QK values between position 1 and positions (1, 2, 3, 4): position 1 is taken as the center with coordinate (0, 0), and every other position records its coordinate offset from that center.
Figure 6: how the relative position index is derived
What Figure 6 finally produces is the relative position index, relative_position_index.shape = (M², M²), registered in the network as a non-learnable buffer. Its job is to look up the corresponding learnable relative position bias for each token pair. The index values range over 0~8, i.e. (2M−1)·(2M−1) = 9 possible values, so the relative position bias can be stored in a 3×3 table, as shown in Figure 7:
Figure 7: relative position bias
The values 0-8 in Figure 7 are index values; each index maps to a learnable bias entry (one value per attention head, since the table has shape ((2M−1)², num_heads)). Per Figure 1, each token computes M² QK values, and every QK value gets its corresponding relative position bias added.
Continuing with the M = 2 window from Figure 6: when computing the M² QK values for position 1, the indices used are relative_position_index = [4, 5, 7, 8] (M² of them), which select the entries at positions 4, 5, 7, 8 of the table in Figure 7. Gathering one entry for each of the M²·M² token pairs yields, per head, relative_position_bias.shape = (M², M²).
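To make Figure 6 concrete, here is a minimal pure-Python sketch (no PyTorch needed) of the same computation for window_size = (2, 2); the list comprehension mirrors the subtract, shift, scale, and sum steps in `__init__`. Note that with the code's sign convention (center coordinate minus the other coordinate), position 1's indices [4, 5, 7, 8] appear down the first column of the matrix rather than across the first row.

```python
# Pure-Python sketch of the relative position index for a 2x2 window (M = 2),
# mirroring the tensor ops in WindowAttention.__init__.
M = 2
# flattened (h, w) coordinates of the 4 tokens, in row-major order
coords = [(h, w) for h in range(M) for w in range(M)]

# pairwise coordinate differences, shifted to start from 0, with the row
# offset scaled by (2*M - 1) so that e.g. (2, 1) and (1, 2) stay distinct
index = [[((hi - hj) + M - 1) * (2 * M - 1) + ((wi - wj) + M - 1)
          for (hj, wj) in coords]
         for (hi, wi) in coords]

for row in index:
    print(row)
```

All entries fall in 0~8, matching the 3×3 bias table described above.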
The relative position encoding is applied in WindowAttention in the source code; once the idea is clear, the code is easy to follow:
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim  # number of input channels
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])  # tensor([0, 1, ..., Wh-1]), length Wh
        coords_w = torch.arange(self.window_size[1])  # tensor([0, 1, ..., Ww-1]), length Ww
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # The 2-D offsets will next be flattened into 1-D indices. Coordinates such as
        # (2, 1) and (1, 2) are different in 2-D, but simply summing x and y would map
        # both to the same 1-D offset, so the row offset is first scaled to keep them
        # distinguishable.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # compute the relative position index
        # relative_position_index.shape = (M^2, M^2), one row per token in the window
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # register relative_position_index as a buffer that does not receive gradients
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # fill the all-zero relative_position_bias_table with values drawn
        # from a truncated normal distribution
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
               N: number of patches in a window
               C: channel dimension produced by the preceding linear projection
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # x.shape = (num_windows*B, N, C)
        # self.qkv(x).shape = (num_windows*B, N, 3C)
        # after reshape: (num_windows*B, N, 3, num_heads, C//num_heads)
        # after permute: (3, num_windows*B, num_heads, N, C//num_heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q.shape = k.shape = v.shape = (num_windows*B, num_heads, N, C//num_heads)
        # N = M^2 is the number of patches; C//num_heads is the per-head dimension of Q, K, V
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # scale q, corresponding to dividing by sqrt(d) in the formula
        q = q * self.scale
        # attn.shape = (num_windows*B, num_heads, N, N), N = M^2 patches
        attn = (q @ k.transpose(-2, -1))
        # self.relative_position_bias_table.shape = (2*Wh-1 * 2*Ww-1, nH)
        # self.relative_position_index.shape = (Wh*Ww, Wh*Ww)
        # every value in relative_position_index selects a row of relative_position_bias_table;
        # the index itself is precomputed and not learnable
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # attn.shape = (num_windows*B, num_heads, M^2, M^2)
        # .unsqueeze(0) inserts a dimension of size 1 at position 0:
        # relative_position_bias.unsqueeze(0).shape = (1, num_heads, M^2, M^2)
        # the leading 1 broadcasts over num_windows*B, so every window in every batch
        # shares the same index matrix and bias table
        attn = attn + relative_position_bias.unsqueeze(0)
        # mask.shape = (num_windows, M^2, M^2)
        # attn.shape = (num_windows*B, num_heads, M^2, M^2)
        if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M^2, M^2):
            # the first M^2 is the number of tokens, the second M^2 is the number of QK^T values per token
            # mask.unsqueeze(1).unsqueeze(0).shape = (1, num_windows, 1, M^2, M^2)
            # the two are added via broadcasting
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M^2, M^2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        # v.shape = (num_windows*B, num_heads, M^2, C//num_heads); C//num_heads is the per-head dimension
        # attn.shape = (num_windows*B, num_heads, M^2, M^2)
        # (attn @ v).shape = (num_windows*B, num_heads, M^2, C//num_heads)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # B_ = num_windows*B, N = M^2, C = num_heads * (C//num_heads)
        # self.proj = nn.Linear(dim, dim), dim = C
        # self.proj_drop = nn.Dropout(proj_drop)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # x.shape = (num_windows*B, N, C), N = number of patches in the window

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
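The bias lookup in forward, `relative_position_bias_table[relative_position_index.view(-1)]`, is just a per-pair table gather. A pure-Python sketch for M = 2 with a single head (the table values here are made up for illustration; in the real module they are learned):

```python
# hypothetical single-head bias table: one scalar per relative position,
# 9 = (2*2 - 1)**2 entries, values made up for the demo
bias_table = [round(0.1 * i, 1) for i in range(9)]

# relative position index for a 2x2 window (as derived above)
index = [[4, 3, 1, 0],
         [5, 4, 2, 1],
         [7, 6, 4, 3],
         [8, 7, 5, 4]]

# gather: replace every index with its table entry -> the (M^2, M^2) bias
bias = [[bias_table[j] for j in row] for row in index]
print(bias[0])
```

Every diagonal entry uses index 4 (zero offset), so each token shares the same "self" bias.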
The code above contains a mask-related passage:
if mask is not None:
    nW = mask.shape[0]
    # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M^2, M^2):
    # the first M^2 is the number of tokens, the second M^2 is the number of QK^T values per token
    # mask.unsqueeze(1).unsqueeze(0).shape = (1, num_windows, 1, M^2, M^2)
    # the two are added via broadcasting
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    # attn.shape = (B, num_windows, num_heads, M^2, M^2)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)
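Why adding the mask before softmax works: a large negative value at a masked position drives its softmax weight to effectively zero, so tokens from different regions of a shifted window do not attend to each other. A minimal sketch with made-up scores for one attention row of N = 4 tokens (here -100 stands in for the 0/-inf mask described in the docstring):

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of floats
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

# made-up attention scores for one query token in a window of 4 tokens
scores = [0.5, 1.0, 0.2, -0.3]
# tokens 2 and 3 belong to a different region, so they are masked out
mask = [0.0, 0.0, -100.0, -100.0]

weights = softmax([s + m for s, m in zip(scores, mask)])
print([round(w, 4) for w in weights])  # masked positions get ~0 weight
```

After the addition, softmax renormalizes over the unmasked positions only, which is exactly the effect intended by the shifted-window mask.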