resnet论文_【SOT】Siamese Mask论文和代码解析

58abedd99db4e324e17821c76ef6aff4.png

0. 前言

前几期已经对单目标检测领域(SOT)的一些成果进行了论文的阅读代码的解析,总结如下:

周威:【SOT】siameseFC论文和代码解析​zhuanlan.zhihu.com
287c8a0b2cd575bd5cb005d13c40f983.png
周威:【SOT】Siamese RPN论文解读和代码解析​zhuanlan.zhihu.com
9c066100e4a2c44a1e767ae5fd70beff.png
周威:【SOT】Siamese RPN++ 论文和代码解析​zhuanlan.zhihu.com
3cb0268d0d5cc13367d0b2c4dac89e01.png

本文将结合代码和论文对Siamese Mask进行详细解析,进一步了解Siamese家族用于SOT的重要成员之一,即Siamese Mask。

这么我们先放官方给的一段视频。

1a9d323f030896551ac1224dbd675865.gif

从视频中可以看出,我们只需要给出视频的第一帧中目标位置,更准确地,使用bounding box来框出目标的位置,并不需要手动mask。那么该模型(Siamese Mask)能够在后续的帧中对框出的物体进行追踪,并实时地进行Mask顺便能够给出实时旋转bounding box。

看了上面的视频,我们不禁会有以下疑问:

  • (1)第一帧给出bounding box,后面怎么就可以直接mask?
  • (2)这个旋转的bounding box是怎么样来的?
  • (3)按道理来说,Mask都是比较慢的,如何做到实时性呢?

都考虑到这里了,不如看看论文和代码。论文和代码链接如下:

Paper :https://arxiv.org/abs/1812.05050

Code: https://github.com/foolwood/SiamMask

本文将通过以下两个方向对论文进行解析,分别为

  • (1)Siamese Mask网络结构解析
  • (2)Siamese Mask网络损失函数设定

1. Siamese Mask网络结构解析

论文中给出了一个简化版的Siamese Mask网络结构图。

f7986e5a110c257924def3f1e5c7a804.png

大致的结构主要由以下部分构成:

  • (1)改进的ResNet-50作为Siamese Network 特征提取网络
  • (2)Depth-Wise Cross-correlation 获取的Response Map
  • (3)三分支或者两分支的Head

大致的前向(Inference)流程是这样的:

  • 输入图片template image(127x127x3)和Search image(255x255x3)分别被输入到特征提取网络
    中,获得15x15x256以及31x31x256大小的特征图
  • 为了降低参数运算,Siamese Mask借鉴了Siamese RPN++中的Depth-Wise Cross-correlation (也就是图中的*d),将15x15x256大小特征图作为卷积核,与31x31x256做互相关运算,获得17x17x256大小的response map
  • 将获取的response map通过三个分支分别获取mask、bbox regression 以及前景背景分类得分score信息,或者通过两个分支获取mask以及每个RoW的得分score信息。

(1)改进的ResNet-50作为Siamese Network 特征提取网络

更具体地,论文中提到

Network architecture. For both our variants, we use a ResNet-50 until the final convolutional layer of the 4-th stage as our backbone
. In order to obtain a high spatial resolution in deeper layers, we reduce the output stride to 8 by using convolutions with stride 1. Moreover, we increase the
receptive field by using dilated convolutions . In our model, we add to the shared backbone
an unshared adjust layer (1x1 conv with 256 outputs)

可见,论文是对ResNet-50的网络结构进行了一些改进,并将其作为特征提取网络

。有关将ResNet-50作为Siamese Network特征提取网络而造成
平移不变性被破坏的问题,我们在Siamese RPN++网络的解析中已经说明过了,有兴趣可以看看。

这里简单说一下,论文中提到:

During training, we randomly jitter examplar and search patches. Specifically, we consider random translations (up to
8 pixels) and
rescaling (of
and
for examplar and search respectively)

上面的random就是缓解ResNet-50等(带padding的)网络带来的平移不变性被破坏的问题。

在Siamese RPN++解析过程中已经给出改进的ResNet-50代码了,这里为了不占过多篇幅,不再多说。(其实是在代码中没有找到罢了)

(2) Depth-Wise Cross-correlation 获取的Response Map

这里我们仍然借用解析Siamese RPN++时候使用到的Depth-Wise Cross-correlation图示,如下图c所示。

259acd021dcca86dec309ad101181a2f.png

核心思想就是分组卷积,这在代码实现过程中,使用grounps这个参数进行设置即可。代码实现如下:

def xcorr_depthwise(x, kernel):
    """depthwise cross correlation
    """
    batch = kernel.size(0)
    channel = kernel.size(1)
    x = x.view(1, batch*channel, x.size(2), x.size(3))
    kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3))
    out = F.conv2d(x, kernel, groups=batch*channel)
    out = out.view(batch, channel, out.size(2), out.size(3))
    return out

可见代码

out = F.conv2d(x, kernel, groups=batch*channel)

中使用了groups参数。

具体有关groups参数设置,读者可以去CSDN搜一下,比较简单,不多赘述。

分组卷积的优点就是可以明显降低模型的参数量提高模型运行速度,特别是在算力不是很强的移动设备上。

(3)三分支或者两分支的Head

有关模型的head分支,文章中给出了两种,分别为

  • three-branch variant head
  • two-branch variant head

这里我们只对三分支进行解析(其实二分支和三分支也差不多)。

第一张图中已经给出了大致的结构。

f7986e5a110c257924def3f1e5c7a804.png

论文中提到:

In particular, in our case this representation corresponds to one of the (17x17) RoWs produced by the depth-wise cross-correlation between
and
. Importantly, the network
of the segmentation task is composed of
two 1x1 convolutional layers, one with 256 and the other with
channels (Figure 2).

大致的意思就是,通过上面的Depth-Wise Cross-correlation 获得了17x17x256的特征图,我们以该特征图为输入,通过2个1x1卷积

),提升特征图的维度至
。然后我们沿着
维度所在的dim,将
的特征图分为17x17个RoWs。

接着论文提及到:

This allows every pixel classifier to utilise information contained in the entire RoW and thus to have a complete view of its corresponding candidate window in x, which is critical to disambiguate between instances that look like the target (e.g. last row of Figure 4), often referred to as distractors.

意思就是希望获得的17x17个RoWs中能够有一个RoW,他包含了mask全部信息。然后通过这个RoW中的信息(维度为

),将其
映射回原图大小,便获取了物体的mask信息。

但是这种方法其实精度不是很好,我们在学习语义分割的一些模型时(如FCN、U-Net等),通过会考虑不同感受野特征图的融合来提高分割的精度。作者论文也提到:

With the aim of producing a more accurate object mask, we follow the strategy of [44], which merges low and high resolution features using multiple refinement modules made of upsampling layers and skip connections。

也就是利用了上采样skip connections提高分割精度。文章的附录中也给出了结构图。

d664261ac696b5185a26608f6e4cdece.png

从结构图可以看出来,该方法没有对获得的response map进行1x1升维

,而是直接选择
最好的一个RoW直接进行 上采样,这里选择了 反卷积进行上采样。然后不断地与中间层的特征图 进行融合,获得高精度的分割模型。

在Siamese Mask中,作者使用mask来获得旋转的bounding box 的,也就是先获得mask,然后根据mask选择最小外接矩阵,为其bounding box即可。论文中这么说的。、

We consider three different strategies to generate a bounding box from a binary mask (Figure 3): (1) axis-aligned bounding rectangle (Min-max), (2) rotated minimum bounding rectangle (MBR) and (3) the optimisation strategy used for the automatic bounding box generation proposed in VOT-2016 [26] (Opt).

作者提到他考虑了三种办法,效果如下图所示

2873c0926d0adc8707aabcd6b78d1167.png

其中红色框是第一中方法(轴对齐),绿色就是MBR,蓝色的就是the optimisation strategy,从图上看,个人认为绿色的比较好。

至此,有关Siamese Mask的结构就解析完毕了。

2. Siamese Mask网络损失函数设定

Siamese Mask网络损失函数设定也比较简单,这里我们只对mask的损失函数进行解析。

论文中给出了该损失函数,如下:

366d78c1cc863bbb5e86f260ef439c5b.png

论文中是这样解释的

828accb6599f8785433313a7d25b3872.png

也就是在训练过程中,每个RoW(共17x17个RoWs)被标记为

,也就是公式中的
,这个n就是RoW的编号ID。

因为这是一个mask分割的任务,那么mask的ground truth 就是和原图大小一致的二值化图(大小为

,每个像素点的值不是1就是-1,用
表示),同时模型的mask分支的输出大小也为
,每个像素点的值用
表示。

那么损失函数

366d78c1cc863bbb5e86f260ef439c5b.png

含义如下:

  • (1)当
    =-1时,说明这个RoW是负样本,那么1+
    =0,不考虑计算损失
  • (2)当
    =1 时,说明这个RoW是正样本,使用 logistic regression loss来计算mask loss。

代码实现为:

loss = F.soft_margin_loss(p_m, mask_uf)

调用了soft_margin_loss损失即可。

至此,对mask 损失的讲解也结束了。

3. 总结

本文我们对Siamese Mask进行了详细的解析,本文更侧重对论文的解读。有关代码上的解析,我只是稍微扫了一眼,并没仔细看。后面的论文解析应当更侧重于论文的理解,代码只是辅助工具,要不总是一大堆的代码复制,我自己复习的时候看的都头疼啊哈哈。至此,有关Siamese Network在SOT任务一些重要网络我们就讲解结束了,后面可能要进入多目标检测领域MOT的学习了。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值