动机:打破CNN做backbone的僵局。
贡献:
1、依靠迭代的混合注意力机制(MAM)设计端到端的跟踪模块,可以代替CNN实现特征的提取,可以代替互相关操作实现模板与搜索图像之间的关联。
2、在MAM中设计了一个定制的非对称关注作用于在线模板更新,,并提出了一个有效的评分预测模块来选择高质量的模板。
3、性能优异
Mixed Attention Module (MAM)
代码:前面都是一些配置文件在expr_func(settings)处正式进入文件train_script_mixformer.py的run函数,开始加载数据集和配置优化器。进入训练
# train process
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
注意:输入数据为模板(8组每组两张),搜索(8组每组一张)
原因是,MixFormer提出了一种在线更新模块(SPM),所以需要两组模板图像用于训练SPM
训练框图:
第一步,给每个模板和搜索加上位置编码
self.patch_embed = ConvEmbed(
# img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
stride=patch_stride,
padding=patch_padding,
embed_dim=embed_dim,
norm_layer=norm_layer
)
template = self.patch_embed(template)
online_template = self.patch_embed(online_template)
t_B, t_C, t_H, t_W = template.size()
search = self.patch_embed(search)
s_B, s_C, s_H, s_W = search.size()
class ConvEmbed(nn.Module):
""" Image to Conv Embedding
"""
def __init__(self,
patch_size=7,
in_chans=3,
embed_dim=64,
stride=4,
padding=2,
norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.proj = nn.Conv2d( #通道3->64, 卷积核大小(7, 7), 步长(4, 4), 填充(2, 2)
in_chans, embed_dim,
kernel_size=patch_size,
stride=stride,
padding=padding
)
self.norm = norm_layer(embed_dim) if norm_layer else None
def forward(self, x):#x = template(8, 3, 128, 128)
x = self.proj(x)#(8, 64, 32, 32)
B, C, H, W = x.shape#获取尺度信息
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()#变形(8, (32*32), 64)
if self.norm:
x = self.norm(x) #layerNorm
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W).contiguous()
#变形(8, 64, 32, 32)
return x
首先,通过self.proj(2D卷积降维),再通过norm归一化。这里位置编码和平常的操作不一样,不是常规的卷积生成一个可学习的矩阵然后叠加上去,可能作者认为卷积操作的可学习型可以直接生成包含位置编码的信息,可以提供位置编码信息。值得一提的这里的norm用的layernorm,平常接触的都是batchnorm,这两者是有区别的(详细可参考)
template = rearrange(template, 'b c h w -> b (h w) c').contiguous()
online_template = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()
search = rearrange(search, 'b c h w -> b (h w) c').contiguous()
x = torch.cat([template, online_template, search], dim=1)
x = self.pos_drop(x)
模板们和搜索做拼接(以组和通道做高宽,以原始的(w,h)拉伸作为拼接的通道)
x = self.pos_drop(x),随机使一些数不作用,防止过拟合,增强泛化能力。