SwinTransformer-Segmentation Code Walkthrough

Link: https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation

Structure Overview:

  • Backbone: Swin Transformer
    • Patch Embedding
    • A series of BasicLayers (stages), each containing:
      • n Swin Transformer Blocks
        • W-MSA / SW-MSA
        • FFN / MLP
      • Patch Merging (downsample)
  • Decode_Head: UPerHead
  • Auxiliary_Head: FCNHead

Patch Embedding (Patch Partition)

  1. Intention: split the image into non-overlapping patches.

  2. It is just a conv2d: input shape (B, C, H, W), output shape (B, embed_dim, Wh, Ww), where Wh = H // patch_size and Ww = W // patch_size.

nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
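
A minimal runnable sketch of this step (the module name and toy shapes are for illustration; the real PatchEmbed in the repo additionally pads inputs whose H or W is not divisible by patch_size and can apply an optional norm layer):

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into patches and project each patch to embed_dim channels."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        # A strided conv is equivalent to "cut into patch_size x patch_size patches
        # and apply a shared linear projection to each flattened patch".
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):        # x: (B, C, H, W)
        x = self.proj(x)         # (B, embed_dim, Wh, Ww), Wh = H // patch_size
        return x

x = torch.randn(2, 3, 224, 224)
print(PatchEmbed()(x).shape)     # torch.Size([2, 96, 56, 56])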

BasicLayer

W-MSA / SW-MSA

# 1. Window partition: reshape so that each (window_size x window_size) window becomes one batch element
(B, H, W, C)
-> view:    (B, H // window_size, window_size, W // window_size, window_size, C)
-> permute: (B, H // window_size, W // window_size, window_size, window_size, C)
-> view:    (B * H // window_size * W // window_size, window_size, window_size, C)
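
For reference, a sketch of window_partition and its inverse window_reverse, matching the shape flow above (this mirrors the standard Swin helper functions; treat it as an illustrative re-implementation rather than the exact source):

import torch

def window_partition(x, window_size):
    # x: (B, H, W, C) -> windows: (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    # windows: (num_windows * B, window_size, window_size, C) -> x: (B, H, W, C)
    B = windows.shape[0] // (H * W // window_size // window_size)
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

x = torch.randn(2, 56, 56, 96)
w = window_partition(x, 7)                       # (2 * 8 * 8, 7, 7, 96)
assert torch.equal(window_reverse(w, 7, 56, 56), x)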

mask_windows's shape: (nW, window_size * window_size), where nW = (Hp // window_size) * (Wp // window_size) is the number of windows (the batch and channel dimensions of img_mask are both 1).

attn_mask's shape: (nW, window_size * window_size, window_size * window_size), obtained by broadcasting mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2).

Efficient batch computation

After the shifted window partitioning, a naive implementation has two problems:

  1. The number of windows increases from ceil(h/M) * ceil(w/M) to (ceil(h/M) + 1) * (ceil(w/M) + 1), e.g. from 8 * 8 = 64 to 9 * 9 = 81 windows for h = w = 56 and M = 7.

  2. The window sizes become irregular, and some windows are smaller than the original (M, M).

How can we still batch the computation efficiently? Padding every small window up to (M, M) adds extra computation, so no.

The answer is the Cyclic Shift.

Cyclic Shift 

After shifting, we still have to compute self-attention inside windows 1-9 (the nine regions produced by the shifted partition).

To compute them exactly like the regular W-MSA, we cyclically roll the whole feature map by M // 2: every element moves left by M // 2 and up by M // 2. After the roll, the 5th (central) window can be computed with the plain W-MSA method, but the other windows would give wrong results, because regions that are not adjacent in the original image end up in the same window: (6, 4) get mixed, (8, 2) get mixed, and (1, 3, 7, 9) get mixed.
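
A tiny toy demonstration of the cyclic shift with torch.roll (the sizes here are illustrative, not from the repo):

import torch

x = torch.arange(16).view(1, 4, 4, 1)          # (B, H, W, C) toy feature map
shift_size = 2                                  # = window_size // 2 for window_size = 4
shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
print(x[0, :, :, 0])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
print(shifted[0, :, :, 0])
# tensor([[10, 11,  8,  9],
#         [14, 15, 12, 13],
#         [ 2,  3,  0,  1],
#         [ 6,  7,  4,  5]])
# rolling back with positive shifts restores the original layout
restored = torch.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(restored, x)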

To fix this mixing problem, we add a mask when computing self-attention:

import numpy as np
import torch

def generate_attn_mask(H, W, window_size, shift_size):
    # 1. label each of the 9 regions of the (padded) feature map with an id 0-8
    Hp = int(np.ceil(H / window_size)) * window_size
    Wp = int(np.ceil(W / window_size)) * window_size
    img_mask = torch.zeros([1, Hp, Wp, 1])
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))

    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    # 2. cut img_mask into real (window_size, window_size) windows
    # ref: https://zhuanlan.zhihu.com/p/370766757
    mask_windows = window_partition(img_mask, window_size)             # (nW, window_size, window_size, 1)
    mask_windows = mask_windows.view(-1, window_size * window_size)    # (nW, window_size*window_size)
    # pairwise difference of region ids: 0 means "same region", non-zero means "different regions"
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # (nW, window_size*window_size, window_size*window_size)
    # adding -100 makes softmax give ~0 weight to cross-region pairs, so they are effectively masked out
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
    return attn_mask

Finally we have the mask, and all shifted windows can be processed in one parallel batch.
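
Inside the window attention, the mask is simply added to the attention logits before the softmax. A minimal sketch (a simplification of the usual Swin WindowAttention, without projections or relative position bias; the broadcasting over nW follows the standard implementation):

import torch

def masked_window_attention(q, k, v, attn_mask=None, num_heads=3):
    # q, k, v: (B_, num_heads, N, head_dim), with B_ = nW * B and N = window_size * window_size
    scale = q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)             # (B_, num_heads, N, N)
    if attn_mask is not None:
        nW, N, _ = attn_mask.shape
        # broadcast the per-window mask over the batch and head dimensions
        attn = attn.view(-1, nW, num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, N, N)
        # after adding -100, softmax assigns ~0 weight to cross-region positions
    attn = attn.softmax(dim=-1)
    return attn @ v                                       # (B_, num_heads, N, head_dim)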

SwinTransformerBlock

Structure

The only difference from a standard Transformer block is that the attention is W-MSA / SW-MSA.

LayerNorm -> W-MSA / SW-MSA (+ residual) -> LayerNorm -> MLP (+ residual)

shortcut = x
x = self.norm1(x)
x = padding(x)          # pad H, W up to multiples of window_size -> (B, Hp, Wp, C)

if shift_size > 0:
    # cyclic shift; the matching attn_mask was precomputed as above
    x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# window partition
x = window_partition(x, window_size)         # (nW*B, window_size, window_size, C)
x = x.view(-1, window_size * window_size, C)

# W-MSA, or the parallel SW-MSA when shifted
if shift_size > 0:
    x = self.attention(x, mask=attn_mask)
else:
    x = self.attention(x)

# merge windows back into a feature map
x = x.view(-1, window_size, window_size, C)
x = window_reverse(x, window_size, Hp, Wp)   # (B, Hp, Wp, C)

if shift_size > 0:
    # reverse the cyclic shift
    x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))

x = x[:, :H, :W, :]     # remove the padding added above

# residual connections (the original code uses drop_path here instead of dropout)
x = shortcut + self.dropout(x)
x = x + self.dropout(self.mlp(self.norm2(x)))

Patch Merging

ref: https://zhuanlan.zhihu.com/p/367111046

self.norm = nn.LayerNorm(4 * dim)
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

x = padding(x)  # (N, H, W, C), pad H and W to even numbers if needed
x0 = x[:, 0::2, 0::2, :]  # (N, H/2, W/2, C)
x1 = x[:, 0::2, 1::2, :]  # (N, H/2, W/2, C)
x2 = x[:, 1::2, 0::2, :]  # (N, H/2, W/2, C)
x3 = x[:, 1::2, 1::2, :]  # (N, H/2, W/2, C)

x = torch.cat([x0, x1, x2, x3], dim=-1)  # (N, H/2, W/2, 4C)
x = x.view(N, -1, 4 * C)                 # (N, H/2 * W/2, 4C)
x = self.norm(x)
x = self.reduction(x)                    # channels: 4C -> 2C

FCNHead

The multi-scale features fed to the heads are the output feature maps of every BasicLayer / stage of the Swin backbone; the auxiliary FCNHead takes one of these stage outputs as its input, while UPerHead consumes all of them.
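
A sketch of how an mmsegmentation-style config wires the stage outputs into UPerHead and FCNHead (the numbers below follow the typical Swin-T / ADE20K setup and are illustrative, not copied from the repo):

# channels of the four stage outputs for Swin-T: embed_dim * 2**i = 96, 192, 384, 768
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        out_indices=(0, 1, 2, 3)),           # return the output of every BasicLayer / stage
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768],      # consumes all four stage outputs
        in_index=[0, 1, 2, 3],
        channels=512,
        num_classes=150),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,                      # consumes only the 3rd stage output
        in_index=2,
        channels=256,
        num_classes=150))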

