论文地址:https://arxiv.org/pdf/2111.15193.pdf
代码地址:https://github.com/OliverRensu/Shunted-Transformer
模型在 SSA.py 文件中定义,并通过 @register_model 装饰器注册:
具体流程如下:
step1: model = ShuntedTransformer()
@register_model
def shunted_t(pretrained=False, **kwargs):
    """Build the ``shunted_t`` configuration of ``ShuntedTransformer``.

    Args:
        pretrained: Accepted as part of the factory signature; the code shown
            here never reads it, so this function loads no weights itself.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``ShuntedTransformer``.

    Returns:
        The constructed model, with ``default_cfg`` set to ``_cfg()``.
    """
    # Four stages: widths 64/128/256/512, depths 1/2/4/1, and a spatial
    # reduction ratio that shrinks from 8 to 1 as the stages go deeper.
    model = ShuntedTransformer(
        patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[1, 2, 4, 1], sr_ratios=[8, 4, 2, 1], num_conv=0,
        **kwargs)
    model.default_cfg = _cfg()
    return model
step2: Class ShuntedTransformer()
class ShuntedTransformer(nn.Module):
    """Multi-stage backbone followed by a classification head.

    This is an excerpt: some simple definitions are omitted, including the
    constructor body that creates ``patch_embed{i}``, ``block{i}``,
    ``norm{i}``, ``num_stages`` and ``head``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, num_conv=0):
        """Constructor body omitted in this excerpt."""

    def forward_features(self, x):
        """Run every stage and average the final tokens.

        Args:
            x: Input batch; assumed to be an image tensor of shape
                (B, C, H, W) -- TODO confirm against the patch embedding.

        Returns:
            The last stage's output averaged over dimension 1 (the token
            axis of a (B, N, C) tensor).
        """
        B = x.shape[0]
        for i in range(self.num_stages):  # The paper uses 4 stages in total.
            # Per the original notes, every patch_embed is implemented with Conv2d.
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            # Per the original notes: nn.ModuleList([Block() for j in range(depths[i])]).
            block = getattr(self, f"block{i + 1}")
            # Per the original notes: nn.LayerNorm.
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            if i != self.num_stages - 1:
                # Fold the token sequence back into a (B, C, H, W) feature map
                # so the next stage's patch embedding can consume it.
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        return x.mean(dim=1)

    def forward(self, x):
        """Extract pooled features and apply the classification head."""
        x = self.forward_features(x)
        x = self.head(x)  # Per the original notes: self.head = nn.Linear(embed_dim, num_class).
        return x
step3: Class Block()
class Block(nn.Module):
    """Residual block: an attention sub-layer followed by an MLP sub-layer.

    Only ``forward`` is shown in this excerpt; ``norm1``, ``norm2``, ``attn``,
    ``mlp`` and ``drop_path`` are created in a constructor that is not shown.
    """

    def forward(self, x, H, W):
        """Apply both residual sub-layers to ``x``.

        Args:
            x: Input to the block; assumed to be a (B, N, C) token tensor
                -- TODO confirm against the caller.
            H: Passed through to ``self.attn`` and ``self.mlp``.
            W: Passed through to ``self.attn`` and ``self.mlp``.

        Returns:
            ``x`` after the attention and MLP residual updates.
        """
        # Per the original notes, self.attn is an Attention module.
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
step4: Class Attention():重要!!
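论文中的公式如下:
$$Q_i = X W_i^{Q}$$
$$K_i,\ V_i = \mathrm{MTA}(X, r_i)\, W_i^{K},\ \mathrm{MTA}(X, r_i)\, W_i^{V}$$
$$V_i = V_i + \mathrm{LE}(V_i)$$
$$h_i = \mathrm{Softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_h}}\right) V_i$$
其中:MTA(·; r_i) 表示第 i 个 head 中下采样率为 r_i 的多尺度 token 聚合(对应代码中的 sr1 / sr2 卷积);LE(·) 是局部增强模块,用逐通道(depth-wise)卷积实现,即每个通道各成一组的分组卷积(对应代码中的 local_conv1 / local_conv2)。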
class Attention(nn.Module):
def forward(self, x, H, W):
B, N, C = x.shape # x = (b,n,c)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# self.q() = nn.Linear(dim,dim)
# q = (b, heads=8, n, c/heads);
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W) # x_ = (b,c,h,w)
# 为什么要做两个x_1,x_2 ?
# sr1:Conv2d(当r=8时,k=8,s=8)
# sr2:Conv2d(当r=8时,k=4,s=4)
x_1 = self.act(self.norm1(self.sr1(x_).reshape(B, C, -1).permute(0, 2, 1)))
x_2 = self.act(self.norm2(self.sr2(x_).reshape(B, C, -1).permute(0, 2, 1)))
# x_1 = (b, hw/8*8, c);
# x_2 = (b, hw/4*4, c);
kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
kv2 = self.kv2(x_2).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
# self.kv1 = nn.Linear(dim,dim)
# kv1 = (2, b, heads/2, hw/8*8, c/heads) 第1,3,5项乘起来是c,其实就是(b,c,hw/r*r),然后把通道数分开;
# kv2 = (2, b, heads/2, hw/4*4, c/heads)
k1, v1 = kv1[0], kv1[1] # ( b, heads/2, hw/8*8, c/heads)
k2, v2 = kv2[0], kv2[1] # ( b, heads/2, hw/4*4, c/heads)
# @表示矩阵乘法;
attn1 = (q[:, :self.num_heads//2] @ k1.transpose(-2, -1)) * self.scale
# attn1 = q:(b, :heads/2, n, c/heads) @ kv1 = (b, heads/2, c/heads, hw/r*r) = (n, hw/r*r)
attn1 = attn1.softmax(dim=-1)
attn1 = self.attn_drop(attn1)
v1 = v1 + self.local_conv1(v1.transpose(1, 2).reshape(B, -1, C//2).
transpose(1, 2).view(B,C//2, H//self.sr_ratio, W//self.sr_ratio)).\
view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
# v1 =( b, heads/2, hw/8*8, c/heads)
# -> self.local_conv1(分组卷积: 将每个通道分成一个组,尺寸不变)= ( b, c/2, h/8, w/8)
# -> ( b, heads/2, hw/8*8, c/heads)
x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2) # x1 = (b,n,c/2) 这一步将hw消掉;
attn2 = (q[:, self.num_heads // 2:] @ k2.transpose(-2, -1)) * self.scale
attn2 = attn2.softmax(dim=-1)
attn2 = self.attn_drop(attn2)
v2 = v2 + self.local_conv2(v2.transpose(1, 2).reshape(B, -1, C//2).
transpose(1, 2).view(B, C//2, H*2//self.sr_ratio, W*2//self.sr_ratio)).\
view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
x2 = (attn2 @ v2).transpose(1, 2).reshape(B, N, C//2)
x = torch.cat([x1,x2], dim=-1)
else: