令
input_resolution = (12, 12)
window_size = 6
shift_size = 3
attn_mask 生成部分的源码如下:
# Build the attention mask for shifted-window attention (SW-MSA).
# When self.shift_size is 0 no mask is built and attn_mask is None.
if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
    # Partition the [H, W] map into labelled regions.
    # Purpose: after the shift, when the map is split into windows, a single
    # window may contain several regions that are not adjacent to each other,
    # so each region needs its own label to tell them apart.
    # Positions carrying the same number were adjacent before the shift.
    # Each axis is cut into three bands: everything before the last
    # window_size entries, the next (window_size - shift_size) entries, and
    # the final shift_size entries.
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    # NOTE: the bare string below is never assigned or used, so it has no
    # effect at runtime. It records an alternative partition noted by the
    # author: two bands per axis (before / within the last shift_size
    # entries) instead of three, with the same "same number = adjacent
    # before the shift" meaning.
    """
# 也可以按如下分区
# 数字相同表示在shift之前区域相邻
h_slices = (slice(0, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.shift_size),
slice(-self.shift_size, None))
"""
    # Give each (row band, column band) combination its own label:
    # 3 x 3 bands -> labels 0..8.
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # Split the label map into windows.
    mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
    # Flatten each window to one row of window_size * window_size labels.
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    # Compute attn_mask.
    # Broadcasting (n, 1, L) - (n, L, 1) yields an (n, L, L) tensor of pairwise
    # label differences inside each window, where L = window_size * window_size.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # Pairs with different labels (non-zero difference) get -100.0 and pairs
    # with the same label get 0.0. Both conditions are evaluated on the
    # difference tensor from the previous line.
    # NOTE(review): presumably this mask is added to the attention logits
    # before softmax so cross-region pairs get ~0 weight — usage is not shown here.
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
    attn_mask = None
window实际划分时不是按照中线四等份划分的(该结论有误,已在下方更新中更正)。
更新:SwinTransformer 中的下采样不是按中线四等份划分的;而 window 划分应该是按照中线四等份划分的(本例中 12×12 的输入、window_size = 6,即划分为 4 个 6×6 的 window)。