Let's start with the overall architecture of SwinTransformer (figure taken from the original paper). Notice that W-MSA (Window-based Multi-head Self-Attention) and SW-MSA (Shifted-Window MSA) are the key components of the Transformer block. Within a stage, W-MSA appears in the odd-numbered layers and SW-MSA in the even-numbered layers: W-MSA only models information inside a single window, while SW-MSA exchanges information between different windows.
Although the architecture diagram draws W-MSA and SW-MSA as two separate modules, at the code level they share one implementation. The only extra work for SW-MSA is that, around the same window-attention computation, the feature map is shifted with a sliding-window operation, i.e. the cyclic shift, and an additional attention mask is computed so that tokens that were not originally adjacent do not attend to each other. The code is analyzed below.
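As a rough illustration of where that difference lives, here is a minimal sketch (not the block's full implementation), assuming a hypothetical input feature map x of shape (B, H, W, C) and shift_size = window_size // 2 as in the paper:

import torch

x = torch.randn(1, 14, 14, 96)  # hypothetical input: B=1, H=W=14, C=96 (swin-tiny stage 1)
shift_size = 7 // 2             # half of the 7x7 window size

# SW-MSA: cyclically shift the feature map so tokens at window borders
# are grouped with tokens from neighboring windows.
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# ... window partition + the same WindowAttention forward pass (now with a mask) ...

# After attention, shift back to restore the original spatial layout.
x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))

For W-MSA, shift_size is 0 and both torch.roll calls are simply skipped.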
The W-MSA code
[Note] The first sentence of the docstring: "Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window."
The comments added to the code below annotate shapes and values using the swin-tiny quantities from the configuration file.
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_  # imports required by this snippet, as in the official Swin Transformer repository

# Window attention
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim  # 96 * 2^layer_index, layer_index = 0, 1, 2, 3, ...
self.window_size = window_size # Wh, Ww (7,7)
self.num_heads = num_heads#[3, 6, 12, 24]
head_dim = dim // num_heads#(96//3=32,96*2^1 // 6=32,...)
self.scale = qk_scale or head_dim ** -0.5#default:head_dim ** -0.5
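# scale = 1/sqrt(head_dim); in the forward pass it is applied to q before the q @ k^T dot product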
# define a parameter table of relative position bias
#Define the relative position bias table
#shape: [(2*7-1)*(2*7-1), 3] = [169, 3] for swin-tiny stage 1
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
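# The relative offset along each axis ranges from -(Wh-1) to +(Wh-1), i.e. 2*Wh-1 possible values,
# so (2*Wh-1)*(2*Ww-1) rows are enough to cover every possible (dy, dx) offset inside a window.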
# get pair-wise relative position index for each token inside the window
#Compute the pair-wise relative position index for every token pair inside the window
coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
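# Broadcasting (2, N, 1) - (2, 1, N) gives the pairwise (dy, dx) differences between all N = Wh*Ww tokens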
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
#Shift the relative coordinates so they start from 0
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
#relative_coords[:, :, 0] *= (2*7-1) = 13, i.e. scale the row offset by 2*Ww-1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
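# Scaling the row offset before summing makes every (dy, dx) pair map to a unique 1-D index,
# analogous to flattening 2-D coordinates into a row-major index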
#These are indices into the relative position bias table of shape (13*13, nHeads), i.e. indices 0-168
#relative_position_index has shape (49, 49); each value in 0-168 selects a row of the bias table
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
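# In the forward pass, the bias gathered from this table is added to the attention logits:
# Attention(Q, K, V) = SoftMax(Q K^T / sqrt(d) + B) V, where B has shape (num_heads, Wh*Ww, Wh*Ww)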
#one linear layer mapping dim -> dim*3, producing q, k and v in a single projection
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#attn_drop=0.0
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
#Initialize the relative position bias table with a truncated normal distribution
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
#Forward pass of the module
def forward(self, x,