MixFormer(论文解读与代码讲解)1

论文地址代码地址代码调试问题

动机:打破CNN做backbone的僵局。

贡献:

1、依靠迭代的混合注意力机制(MAM)设计端到端的跟踪模块,可以代替CNN实现特征的提取,可以代替互相关操作实现模板与搜索图像之间的关联。

2、在MAM中设计了一个定制的非对称关注作用于在线模板更新,,并提出了一个有效的评分预测模块来选择高质量的模板。

3、性能优异

步骤流程

 Mixed Attention Module (MAM)

代码:前面都是一些配置文件在expr_func(settings)处正式进入文件train_script_mixformer.py的run函数,开始加载数据集和配置优化器。进入训练

 # train process
    trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)

注意:输入数据为模板(8组每组两张),搜索(8组每组一张)

 原因是,MixFormer提出了一种在线更新模块(SPM),所以需要两组模板图像用于训练SPM

训练框图:

 第一步,给每个模板和搜索加上位置编码

self.patch_embed = ConvEmbed(
            # img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer
        )        
        template = self.patch_embed(template)
        online_template = self.patch_embed(online_template)
        t_B, t_C, t_H, t_W = template.size()
        search = self.patch_embed(search)
        s_B, s_C, s_H, s_W = search.size()
class ConvEmbed(nn.Module):
    """ Image to Conv Embedding

    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d( #通道3->64, 卷积核大小(7, 7), 步长(4, 4), 填充(2, 2)
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):#x = template(8, 3, 128, 128)
        x = self.proj(x)#(8, 64, 32, 32)

        B, C, H, W = x.shape#获取尺度信息
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()#变形(8, (32*32), 64)
        if self.norm:
            x = self.norm(x) #layerNorm
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W).contiguous()
        #变形(8, 64, 32, 32)
        return x

首先,通过self.proj(2D卷积降维),再通过norm归一化。这里位置编码和平常的操作不一样,不是常规的卷积生成一个可学习的矩阵然后叠加上去,可能作者认为卷积操作的可学习型可以直接生成包含位置编码的信息,可以提供位置编码信息。值得一提的这里的norm用的layernorm,平常接触的都是batchnorm,这两者是有区别的(详细可参考

        template = rearrange(template, 'b c h w -> b (h w) c').contiguous()
        online_template = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()
        search = rearrange(search, 'b c h w -> b (h w) c').contiguous()
        x = torch.cat([template, online_template, search], dim=1)

        x = self.pos_drop(x)

模板们和搜索做拼接(以组和通道做高宽,以原始的(w,h)拉伸作为拼接的通道)

 x = self.pos_drop(x),随机使一些数不作用,防止过拟合,增强泛化能力。

MixFormer(论文解读与代码讲解)2

  • 3
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值