Link: https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation
Structure overview:
- Backbone: Swin Transformer
  - Patch Embedding
  - A series of BasicLayer (Stage)
    - n Swin Transformer Blocks
      - W-MSA / SW-MSA
      - FFN / MLP
    - Patch Merging
- Decode_Head: UperHead
- Auxiliary_Head: FCNHead
Patch Embedding (Patch Partition)
- Intention: split the image into non-overlapping patches
- Just a conv2d: input shape (B, C, H, W), output shape (B, embed_dim, Wh, Ww), with Wh = H // patch_size and Ww = W // patch_size
nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
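A quick shape check of the conv2d above may help. This is a minimal sketch; the concrete values (3 input channels, 4x4 patches, embed_dim of 96, a 224x224 input) are assumptions picked for illustration, not values taken from these notes.

```python
import torch
import torch.nn as nn

# Assumed toy configuration (not from the notes): RGB input, 4x4 patches, 96 channels.
in_chans, embed_dim, patch_size = 3, 96, 4

# kernel_size == stride, so every patch is seen exactly once and patches never overlap.
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, in_chans, 224, 224)   # (B, C, H, W)
out = proj(x)                            # (B, embed_dim, Wh, Ww)
print(out.shape)                         # torch.Size([2, 96, 56, 56])
```

If H or W is not a multiple of patch_size, this conv silently drops the remainder, so the input would need to be padded first.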
BasicLayer
W-MSA / SW-MSA
# 1. Window partition: shape of x after each step
(B, H, W, C)  # input
(B, H // window_size, window_size, W // window_size, window_size, C)  # view
(B, H // window_size, W // window_size, window_size, window_size, C)  # permute(0, 1, 3, 2, 4, 5)
(B * (H // window_size) * (W // window_size), window_size, window_size, C)  # view
With nW = (H // window_size) * (W // window_size) windows per image, and B = 1, C = 1 for the mask:
mask_windows's shape: (nW, window_size * window_size)
attn_mask's shape: (nW, 1, window_size * window_size) - (nW, window_size * window_size, 1), which broadcasts to (nW, window_size * window_size, window_size * window_size)
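The shape trace above can be reproduced step by step. This is only a sketch with assumed toy sizes (B=2, H=W=8, C=3, window_size=4), and it assumes H and W are already multiples of window_size.

```python
import torch

# Assumed toy sizes; H and W are multiples of window_size.
B, H, W, C, window_size = 2, 8, 8, 3, 4

x = torch.randn(B, H, W, C)
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
print(x.shape)        # torch.Size([2, 2, 4, 2, 4, 3])
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
print(x.shape)        # torch.Size([2, 2, 2, 4, 4, 3])
windows = x.view(-1, window_size, window_size, C)
print(windows.shape)  # torch.Size([8, 4, 4, 3])

# Same steps on a one-channel mask of shape (1, H, W, 1), then the pairwise difference.
m = torch.zeros(1, H, W, 1)
m = m.view(1, H // window_size, window_size, W // window_size, window_size, 1)
m = m.permute(0, 1, 3, 2, 4, 5).contiguous()
mask_windows = m.view(-1, window_size * window_size)
print(mask_windows.shape)  # torch.Size([4, 16])

# (nW, 1, N) - (nW, N, 1) broadcasts to (nW, N, N), with N = window_size * window_size.
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
print(attn_mask.shape)     # torch.Size([4, 16, 16])
```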
Efficient batch computation v1
- Shifting the window grid increases the number of windows: ceil(h/M) * ceil(w/M) -> (ceil(h/M) + 1) * (ceil(w/M) + 1)
- Window sizes now vary, and some windows are smaller than the original (M, M)
How can the batch computation stay efficient? Padding every window back to (M, M) works but adds computation, so no.
The answer is Cyclic Shift.
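The window-count increase can be checked with a few lines of arithmetic. This is a sketch only: `num_windows` is a hypothetical helper, and the 8x8 feature map with M=4 is an assumed example that gives the 2x2 -> 3x3 case (9 windows).

```python
import math

def num_windows(h, w, M, shifted):
    """Window count for an h x w feature map with M x M windows.

    shifted=True counts the windows of the displaced grid when no cyclic shift
    is used, which adds one extra row and one extra column of windows.
    """
    rows, cols = math.ceil(h / M), math.ceil(w / M)
    if shifted:
        rows, cols = rows + 1, cols + 1
    return rows * cols

print(num_windows(8, 8, 4, shifted=False))  # 4
print(num_windows(8, 8, 4, shifted=True))   # 9
```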
Cyclic Shift
After shifting, the feature map is divided into windows 1-9, and self-attention has to be computed inside each of them.
To compute self-attention just as in W-MSA, we roll the whole feature map by M // 2:
every row moves up by M // 2 and every column moves left by M // 2, wrapping around at the border. Window 5 can then be computed with the W-MSA method, but the other windows give wrong results because unrelated regions land in the same window: (6, 4) are mixed, (8, 2) are mixed, and (1, 3, 7, 9) are mixed.
To fix the mixing problem, we add a mask when computing self-attention:
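import numpy as np
import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (B * nW, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def generate_attn_mask(H, W, window_size, shift_size):
    # assumes 0 < shift_size < window_size
    # 1. label each of the 9 regions of the padded, shifted feature map with an id from 0 to 8
    Hp = int(np.ceil(H / window_size)) * window_size
    Wp = int(np.ceil(W / window_size)) * window_size
    img_mask = torch.zeros([1, Hp, Wp, 1])
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # 2. cut img_mask into the real (window_size, window_size) windows
    # ref: https://zhuanlan.zhihu.com/p/370766757
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)  # nW, window_size*window_size
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # nW, window_size*window_size, window_size*window_size
    # a nonzero difference means the two tokens come from different regions;
    # softmax -> ~0 for those pairs once -100 is added to their logits
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
    return attn_mask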
Finally we have the mask, and all windows can be computed in parallel as one batch.
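To see why adding -100 is enough, here is a minimal sketch of the mask applied to attention logits. The 4-token window, the region ids and the random logits are assumptions for illustration, and attention heads are left out for brevity.

```python
import torch

# Toy window with 4 tokens: tokens 0-1 come from region 0, tokens 2-3 from region 1.
region_ids = torch.tensor([[0., 0., 1., 1.]])                  # (nW=1, N=4)
attn_mask = region_ids.unsqueeze(1) - region_ids.unsqueeze(2)  # (1, 4, 4)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

logits = torch.randn(1, 4, 4)                     # stand-in for q @ k^T / sqrt(d)
attn = torch.softmax(logits + attn_mask, dim=-1)  # the mask is added before softmax

# Token 0 now attends only to tokens 0 and 1; its weights on tokens 2 and 3 are ~0.
print(attn[0, 0])
```

Because exp(-100) is vanishingly small next to the unmasked terms, tokens from different regions effectively stop attending to each other even though they share a window.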
SwinTransformerBlock
structure
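The only difference from a traditional Transformer block is that standard multi-head self-attention is replaced by W-MSA / SW-MSA.
LayerNorm -> W-MSA / SW-MSA -> LayerNorm -> MLP, with a residual connection around each half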
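import torch
import torch.nn.functional as F

def forward(self, x, attn_mask=None):
    # x: (B, H, W, C); attn_mask is only used when shift_size > 0
    window_size, shift_size = self.window_size, self.shift_size
    B, H, W, C = x.shape
    shortcut = x
    x = self.norm1(x)
    # pad H and W up to multiples of window_size
    pad_r = (window_size - W % window_size) % window_size
    pad_b = (window_size - H % window_size) % window_size
    x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
    _, Hp, Wp, _ = x.shape
    if shift_size > 0:
        # cyclic shift; attn_mask was built for this shifted layout
        x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    # window partition
    x = window_partition(x, window_size)
    x = x.view(-1, window_size * window_size, C)
    # W-MSA or masked SW-MSA; all windows are processed in parallel as one batch
    if shift_size > 0:
        x = self.attention(x, mask=attn_mask)
    else:
        x = self.attention(x)
    # merge the windows back into (B, Hp, Wp, C)
    x = x.view(-1, window_size, window_size, C)
    x = window_reverse(x, window_size, Hp, Wp)
    if shift_size > 0:
        # reverse cyclic shift
        x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
    # drop the padding
    x = x[:, :H, :W, :].contiguous()
    # residual connection, then LayerNorm + MLP (FFN) with a second residual connection
    x = shortcut + self.dropout(x)
    x = x + self.dropout(self.mlp(self.norm2(x)))
    return x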
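The two torch.roll calls in the block undo each other. A small sketch on an assumed 4x4 map with shift_size=1 shows where the rolled-off rows and columns end up and that the reverse roll restores the input.

```python
import torch

shift_size = 1                         # assumed toy value
x = torch.arange(16).view(1, 4, 4, 1)  # (B, H, W, C), values 0..15 for readability

# Cyclic shift: rows move up and columns move left, wrapping around the border.
shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
print(shifted[0, :, :, 0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])

# Reverse cyclic shift: rolling back by +shift_size recovers the original layout.
restored = torch.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(restored, x)
```

The first row and first column wrap to the bottom and right edges, which is exactly where unrelated regions end up sharing a window and why the attention mask is needed.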
Patch Merging
ref: https://zhuanlan.zhihu.com/p/367111046
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = nn.LayerNorm(4 * dim)

x = padding(x)  # (N, H, W, C); padding: pad H and W to even numbers (helper not shown)
x0 = x[:, 0::2, 0::2, :]  # (N, H/2, W/2, C)
x1 = x[:, 0::2, 1::2, :]  # (N, H/2, W/2, C)
x2 = x[:, 1::2, 0::2, :]  # (N, H/2, W/2, C)
x3 = x[:, 1::2, 1::2, :]  # (N, H/2, W/2, C)
x = torch.cat([x0, x1, x2, x3], dim=-1)  # (N, H/2, W/2, 4C)
x = x.view(N, -1, 4 * C)  # (N, H/2 * W/2, 4C)
x = self.norm(x)
x = self.reduction(x)  # from 4C to 2C
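The steps above can be put into a small runnable module. This is a sketch under assumptions: `PatchMergingSketch` is a hypothetical name, the input is kept as (N, H, W, C) for brevity, and the padding rule (one extra row or column at the bottom/right when H or W is odd) is one plausible reading of `padding(x)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchMergingSketch(nn.Module):
    """Hypothetical minimal module: halves H and W, doubles the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        N, H, W, C = x.shape
        # Assumed padding rule: add one row/column at the bottom/right if H/W is odd.
        # F.pad takes pairs starting from the last dim: (C, C, W_left, W_right, H_top, H_bottom).
        x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, 0::2, 0::2, :]  # top-left pixel of each 2x2 group
        x1 = x[:, 0::2, 1::2, :]  # top-right
        x2 = x[:, 1::2, 0::2, :]  # bottom-left
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (N, H/2, W/2, 4C)
        x = x.view(N, -1, 4 * C)                 # (N, H/2 * W/2, 4C)
        return self.reduction(self.norm(x))      # (N, H/2 * W/2, 2C)

merge = PatchMergingSketch(dim=96)
out = merge(torch.randn(2, 7, 7, 96))  # odd H and W get padded to 8x8 first
print(out.shape)                       # torch.Size([2, 16, 192])
```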
FCNHead
The features that Swin Transformer passes on to the heads are the output features of every BasicLayer / Stage, so the heads receive one feature map per stage.
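As a rough guide to what those per-stage features look like, the sketch below computes their shapes from the patch embedding and patch merging rules described earlier. `stage_output_shapes` is a hypothetical helper, and the defaults (patch_size=4, embed_dim=96, 4 stages, merging after every stage except the last) are assumptions rather than values stated in these notes.

```python
import math

def stage_output_shapes(H, W, patch_size=4, embed_dim=96, num_stages=4):
    """(channels, height, width) of each stage's output feature map.

    Assumes patch merging (H, W halved, channels doubled) follows every
    stage except the last one.
    """
    h, w, c = math.ceil(H / patch_size), math.ceil(W / patch_size), embed_dim
    shapes = []
    for i in range(num_stages):
        shapes.append((c, h, w))
        if i < num_stages - 1:
            h, w, c = math.ceil(h / 2), math.ceil(w / 2), 2 * c
    return shapes

print(stage_output_shapes(512, 512))
# [(96, 128, 128), (192, 64, 64), (384, 32, 32), (768, 16, 16)]
```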