Shunted Transformer -- Code Notes

Paper: https://arxiv.org/pdf/2111.15193.pdf

Code: https://github.com/OliverRensu/Shunted-Transformer

        The model is defined in SSA.py via the @register_model decorator.

        The overall flow is as follows:

        step1: model = ShuntedTransformer()

@register_model
def shunted_t(pretrained=False, **kwargs):
    model = ShuntedTransformer(
        patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[1, 2, 4, 1], sr_ratios=[8, 4, 2, 1], num_conv=0,
        **kwargs)
    model.default_cfg = _cfg()

    return model
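
        As a quick sanity check of step1, the registered factory can be called directly. A minimal sketch, assuming the repo's SSA.py is importable and using the default 224x224 input size:

import torch
from SSA import shunted_t   # assumes the repo's SSA.py is on the import path

model = shunted_t()                      # builds the ShuntedTransformer configured above
x = torch.randn(1, 3, 224, 224)          # dummy ImageNet-sized input
logits = model(x)
print(logits.shape)                      # expected: torch.Size([1, 1000])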

        step2: Class ShuntedTransformer()

class ShuntedTransformer(nn.Module):
    """省略了一些简单的定义"""
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, num_conv=0):
    """..."""

    def forward_features(self, x):
        B = x.shape[0]

        for i in range(self.num_stages):    # the paper uses 4 stages in total
            patch_embed = getattr(self, f"patch_embed{i + 1}")    # every patch_embed is implemented with a Conv2d
            block = getattr(self, f"block{i + 1}")    # block = nn.ModuleList([Block() for j in range(depths[i])])
            norm = getattr(self, f"norm{i + 1}")    # nn.LayerNorm
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            if i != self.num_stages - 1:
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        return x.mean(dim=1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)    # self.head = nn.Linear(embed_dims[-1], num_classes)

        return x
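
        Each patch_embed fetched in the loop above is a Conv2d-based downsampler that returns the flattened tokens together with the new spatial size (x, H, W). A minimal sketch of such a module (illustrative only; the repo's actual patch-embedding classes differ in details such as overlapping kernels and the convolutional stem controlled by num_conv):

import torch
import torch.nn as nn

class SimplePatchEmbed(nn.Module):
    """Illustrative patch embedding: a strided Conv2d followed by LayerNorm.
    Matches the (tokens, H, W) return signature used in forward_features above."""
    def __init__(self, in_chans=3, embed_dim=64, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)                     # (B, embed_dim, H/patch_size, W/patch_size)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)     # (B, H*W, embed_dim) -- token sequence
        x = self.norm(x)
        return x, H, W

tokens, H, W = SimplePatchEmbed()(torch.randn(1, 3, 224, 224))
print(tokens.shape, H, W)                    # torch.Size([1, 3136, 64]) 56 56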

        step3: Class Block()

class Block(nn.Module):

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))    # self.attn = Attention()
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x
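
        The modules used in this forward (norm1/norm2, attn, mlp, drop_path) are created in the block's constructor following the usual pre-norm pattern. A rough sketch of that constructor, assuming timm's DropPath and the Attention class from step4 below (the repo's Mlp additionally takes H and W because it contains a depth-wise conv):

import torch.nn as nn
from timm.models.layers import DropPath

class Block(nn.Module):
    """Pre-norm block: x + DropPath(Attn(LN(x))), then x + DropPath(MLP(LN(x))).
    Sketch only -- Attention and Mlp are the repo classes discussed in these notes."""
    def __init__(self, dim, num_heads, mlp_ratio=4., drop_path=0., sr_ratio=1, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio)      # step4 below
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(dim, hidden_features=int(dim * mlp_ratio))               # MLP with a DWConv in the repo
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()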

        step4: Class Attention() -- important!!

        The formulas from the paper are:

\begin{aligned}
Q_{i} &= X W_{i}^{Q} \\
K_{i}, V_{i} &= \mathrm{MTA}(X, r_{i})\, W_{i}^{K},\ \mathrm{MTA}(X, r_{i})\, W_{i}^{V} \\
V_{i} &= V_{i} + \mathrm{LE}(V_{i})
\end{aligned}

        where MTA(·) denotes multi-scale token aggregation (the strided-convolution downsampling applied to the keys/values at rate r_i), and LE(·) is local enhancement, implemented as a grouped (depth-wise) convolution;
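
        The LE term only adds local detail: it is a grouped convolution with one channel per group (i.e. depth-wise) applied to V reshaped back to its spatial layout, and it changes no shapes. A minimal illustration with arbitrary sizes (not the repo's local_conv1 module itself, though that one is built the same way):

import torch
import torch.nn as nn

C, H, W = 32, 7, 7                        # arbitrary channel count and (pooled) spatial size
v = torch.randn(1, C, H, W)               # V reshaped back to (B, C, H', W')

# groups=C -> each channel is convolved on its own (depth-wise);
# kernel 3 with padding 1 keeps the spatial size unchanged
le = nn.Conv2d(C, C, kernel_size=3, padding=1, groups=C)
v = v + le(v)                             # V <- V + LE(V), same shape as before
print(v.shape)                            # torch.Size([1, 32, 7, 7])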


class Attention(nn.Module):
    def forward(self, x, H, W):
        B, N, C = x.shape   # x = (b,n,c)
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # self.q() = nn.Linear(dim,dim)
        # q = (b, heads=8, n, c/heads);
        if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W) # x_ = (b,c,h,w)
                # why two branches x_1 and x_2? each half of the heads gets K/V aggregated at a different rate
                # sr1: Conv2d (for r=8: kernel=8, stride=8)
                # sr2: Conv2d (for r=8: kernel=4, stride=4)
                x_1 = self.act(self.norm1(self.sr1(x_).reshape(B, C, -1).permute(0, 2, 1)))
                x_2 = self.act(self.norm2(self.sr2(x_).reshape(B, C, -1).permute(0, 2, 1)))
                # x_1 = (b, hw/(8*8), c)
                # x_2 = (b, hw/(4*4), c)

                kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
                kv2 = self.kv2(x_2).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
                #   self.kv1 = nn.Linear(dim,dim)
                #   kv1 = (2, b, heads/2, hw/(8*8), c/heads)  dims 1, 3 and 5 (2 * heads/2 * c/heads) multiply back to c,
                #   so this is really (b, c, hw/(r*r)) with the channels split into k/v and across the heads
                #   kv2 = (2, b, heads/2, hw/(4*4), c/heads)
                k1, v1 = kv1[0], kv1[1] # (b, heads/2, hw/(8*8), c/heads)
                k2, v2 = kv2[0], kv2[1] # (b, heads/2, hw/(4*4), c/heads)
                # @ is matrix multiplication
                attn1 = (q[:, :self.num_heads//2] @ k1.transpose(-2, -1)) * self.scale
                # attn1 = q[:, :heads/2] (b, heads/2, n, c/heads) @ k1^T (b, heads/2, c/heads, hw/(r*r)) = (b, heads/2, n, hw/(r*r))
                attn1 = attn1.softmax(dim=-1)
                attn1 = self.attn_drop(attn1)

                v1 = v1 + self.local_conv1(v1.transpose(1, 2).reshape(B, -1, C//2).
                                        transpose(1, 2).view(B,C//2, H//self.sr_ratio, W//self.sr_ratio)).\
                    view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
                    # v1 = (b, heads/2, hw/(8*8), c/heads)
                    # -> self.local_conv1 (grouped conv, one channel per group, spatial size unchanged) = (b, c/2, h/8, w/8)
                    # -> (b, heads/2, hw/(8*8), c/heads)  (this is the LE(V) term from the formula)
                x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2)   # x1 = (b, n, c/2); the hw/(r*r) dimension is summed out here


                # the second half of the heads repeats the same computation on the finer-grained K/V from x_2
                attn2 = (q[:, self.num_heads // 2:] @ k2.transpose(-2, -1)) * self.scale
                attn2 = attn2.softmax(dim=-1)
                attn2 = self.attn_drop(attn2)
                v2 = v2 + self.local_conv2(v2.transpose(1, 2).reshape(B, -1, C//2).
                                        transpose(1, 2).view(B, C//2, H*2//self.sr_ratio, W*2//self.sr_ratio)).\
                    view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
                x2 = (attn2 @ v2).transpose(1, 2).reshape(B, N, C//2)

                x = torch.cat([x1, x2], dim=-1)   # concatenate the two halves back to (b, n, c)
        else:
                # sr_ratio == 1 (the last stage): no token aggregation, all heads share a single K/V set.
                # Sketch of the remaining branch, following the same pattern as above; see the repo's
                # SSA.py for the exact code (which also adds the local_conv LE term to v):
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
                k, v = kv[0], kv[1]
                attn = (q @ k.transpose(-2, -1)) * self.scale
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
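
        To make the shape bookkeeping in the comments above concrete, the standalone sketch below pushes random tensors through the same two-branch arithmetic (sizes follow the sr_ratio=8 comments; the LE terms and the output projection are omitted, and the throwaway nn.Linear/nn.Conv2d layers only stand in for self.q, self.sr1/sr2 and self.kv1/kv2):

import torch
import torch.nn as nn

B, H, W, C, heads, r = 1, 56, 56, 64, 2, 8
N = H * W
x = torch.randn(B, N, C)

q = nn.Linear(C, C)(x).reshape(B, N, heads, C // heads).permute(0, 2, 1, 3)    # (1, 2, 3136, 32)

x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_1 = nn.Conv2d(C, C, r, r)(x_).reshape(B, C, -1).permute(0, 2, 1)             # (1, 49, 64)  = hw/(8*8) tokens
x_2 = nn.Conv2d(C, C, r // 2, r // 2)(x_).reshape(B, C, -1).permute(0, 2, 1)   # (1, 196, 64) = hw/(4*4) tokens

kv1 = nn.Linear(C, C)(x_1).reshape(B, -1, 2, heads // 2, C // heads).permute(2, 0, 3, 1, 4)
kv2 = nn.Linear(C, C)(x_2).reshape(B, -1, 2, heads // 2, C // heads).permute(2, 0, 3, 1, 4)
k1, v1 = kv1[0], kv1[1]                                                        # (1, 1, 49, 32)
k2, v2 = kv2[0], kv2[1]                                                        # (1, 1, 196, 32)

scale = (C // heads) ** -0.5
x1 = ((q[:, :heads // 2] @ k1.transpose(-2, -1)) * scale).softmax(-1) @ v1     # (1, 1, 3136, 32)
x2 = ((q[:, heads // 2:] @ k2.transpose(-2, -1)) * scale).softmax(-1) @ v2     # (1, 1, 3136, 32)

out = torch.cat([x1.transpose(1, 2).reshape(B, N, C // 2),
                 x2.transpose(1, 2).reshape(B, N, C // 2)], dim=-1)
print(out.shape)                                                               # torch.Size([1, 3136, 64])
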
To summarize: compared with a conventional Transformer, the Shunted Transformer explicitly models the multi-scale nature of scene objects inside the self-attention layer, which it achieves through multi-scale token aggregation. The model is organized into four stages, each made up of several Shunted Transformer blocks; within a stage every block outputs feature maps of the same size. Adjacent stages are connected by a stride-2 convolution, which halves the feature-map resolution and doubles the channel dimension. This structure helps the model capture information at different scales and improves its modelling of scene objects. For further implementation details, see the paper and the GitHub repository linked at the top.
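
A quick way to see this stage-wise halving and doubling, using the shunted_t configuration from step1 (purely illustrative arithmetic, with a 224x224 input and a stage-1 patch size of 4):

# Stage-wise schedule for a 224x224 input with embed_dims=[64, 128, 256, 512]:
H = W = 224 // 4                          # 56x56 tokens after the stage-1 patch embedding
for stage, dim in enumerate([64, 128, 256, 512], start=1):
    print(f"stage {stage}: {H}x{W} tokens, dim {dim}")
    H, W = H // 2, W // 2                 # stride-2 patch embedding between stages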
