Paper analysis: https://blog.csdn.net/CsdnWujinming/article/details/88895146
Project repository: https://github.com/foolwood/SiamMask
Network structure
Code walkthrough
1. Custom.py: the concrete implementation of the SiamMask network
ResDownS
After ResNet feature extraction, the features are passed through the adjust layer for downsampling: 1024 input channels, 256 output channels. This corresponds to the adjust layer in the table, applied right after ResNet-50. It is instantiated inside ResDown.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResDownS(nn.Module):
    # inplane: number of input channels
    # outplane: number of output channels
    def __init__(self, inplane, outplane):
        super(ResDownS, self).__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(inplane, outplane, kernel_size=1, bias=False),
            nn.BatchNorm2d(outplane))

    def forward(self, x):
        x = self.downsample(x)
        if x.size(3) < 20:          # only the template branch (15x15) is this small
            l, r = 4, -4
            x = x[:, :, l:r, l:r]   # keep the center 7x7 region
        return x
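To make the crop concrete, here is a minimal shape check. The 15x15/31x31 sizes are what the stride-8 backbone produces for 127x127 template and 255x255 search patches; the tensors below are just stand-ins:

adjust = ResDownS(1024, 256)
z = torch.randn(1, 1024, 15, 15)   # template-branch feature map
x = torch.randn(1, 1024, 31, 31)   # search-branch feature map
print(adjust(z).shape)  # torch.Size([1, 256, 7, 7])   -> cropped
print(adjust(x).shape)  # torch.Size([1, 256, 31, 31]) -> left untouched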
ResDown
The Siamese feature-extraction network, corresponding to the ResNet-50 and adjust operations in Figure 2.
class ResDown(Features):
    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        # ResNet-50 truncated after layer3 (layer4 is not used)
        self.features = resnet50(layer3=True, layer4=False)
        if pretrain:
            load_pretrain(self.features, 'resnet.model')
        self.downsample = ResDownS(1024, 256)

    def forward(self, x):
        output = self.features(x)
        p3 = self.downsample(output[-1])
        return p3

    def forward_all(self, x):
        output = self.features(x)
        p3 = self.downsample(output[-1])
        return output, p3
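As a quick sanity check, the expected shapes through ResDown look as follows. This is a sketch that assumes the SiamMask repo is importable (resnet50 and load_pretrain come from it) and that the backbone has an effective stride of 8:

backbone = ResDown(pretrain=False)
z = torch.randn(1, 3, 127, 127)   # template patch
x = torch.randn(1, 3, 255, 255)   # search patch
print(backbone(z).shape)  # torch.Size([1, 256, 7, 7])
print(backbone(x).shape)  # torch.Size([1, 256, 31, 31])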
UP(RPN)
The bounding-box regression and classification head; both branches are implemented with DepthCorr objects.
class UP(RPN):
    # feature_in:  input channels of the cls/loc branches
    # feature_out: hidden channels of the cls/loc branches
    # cls_output:  output channels of the cls branch (2 scores per anchor)
    # loc_output:  output channels of the loc branch (4 offsets per anchor)
    def __init__(self, anchor_num=5, feature_in=256, feature_out=256):
        super(UP, self).__init__()
        self.anchor_num = anchor_num
        self.feature_in = feature_in
        self.feature_out = feature_out
        self.cls_output = 2 * self.anchor_num
        self.loc_output = 4 * self.anchor_num
        self.cls = DepthCorr(feature_in, feature_out, self.cls_output)
        self.loc = DepthCorr(feature_in, feature_out, self.loc_output)

    def forward(self, z_f, x_f):
        cls = self.cls(z_f, x_f)
        loc = self.loc(z_f, x_f)
        return cls, loc
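DepthCorr itself is defined elsewhere in the repo; its core operation is a depth-wise cross-correlation between template and search features. A minimal self-contained sketch of that operation (the function name and the 7x7/31x31 shapes are illustrative, not the repo's exact code):

def depthwise_xcorr(search, kernel):
    # Correlate each search channel with its matching template channel by
    # casting the template as grouped convolution filters.
    batch, channel = kernel.shape[:2]
    search = search.view(1, batch * channel, search.size(2), search.size(3))
    kernel = kernel.view(batch * channel, 1, kernel.size(2), kernel.size(3))
    out = F.conv2d(search, kernel, groups=batch * channel)
    return out.view(batch, channel, out.size(2), out.size(3))

z_f = torch.randn(1, 256, 7, 7)     # template features
x_f = torch.randn(1, 256, 31, 31)   # search features
print(depthwise_xcorr(x_f, z_f).shape)  # torch.Size([1, 256, 25, 25])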
MaskCorr
The mask branch, also built on a DepthCorr object: 256 input channels and 63*63 = 3969 output channels, one per pixel of a flattened 63x63 mask.
class MaskCorr(Mask):
    def __init__(self, oSz=63):
        super(MaskCorr, self).__init__()
        self.oSz = oSz
        self.mask = DepthCorr(256, 256, self.oSz**2)

    def forward(self, z, x):
        return self.mask(z, x)
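Each spatial position of the resulting correlation map (25x25 given the shapes in the depth-wise correlation example above) carries one flattened 63x63 mask for the corresponding candidate window (RoW). A sketch of how one mask is read out; the tensor here is a stand-in for the MaskCorr output:

masks = torch.randn(1, 63 * 63, 25, 25)            # stand-in for MaskCorr output
row_mask = masks[:, :, 12, 12].view(1, 1, 63, 63)  # mask logits at the center RoW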
Refine
Network structure
The upper half of Figure 2.
The three post attributes correspond to U2, U3 and U4, respectively.
class Refine(nn.Module):
    def __init__(self):
        """
        Mask refinement module
        Please refer to SiamMask (Appendix A)
        https://arxiv.org/abs/1812.05050
        """
        super(Refine, self).__init__()
        # v0/v1/v2: branches that compress the backbone features at each scale
        self.v0 = nn.Sequential(nn.Conv2d(64, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 4, 3, padding=1), nn.ReLU())
        self.v1 = nn.Sequential(nn.Conv2d(256, 64, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(64, 16, 3, padding=1), nn.ReLU())
        self.v2 = nn.Sequential(nn.Conv2d(512, 128, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(128, 32, 3, padding=1), nn.ReLU())
        # h0/h1/h2: branches that process the upsampled mask stream
        self.h2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())
        self.h1 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.h0 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
        self.deconv = nn.ConvTranspose2d(256, 32, 15, 15)
        self.post0 = nn.Conv2d(32, 16, 3, padding=1)
        self.post1 = nn.Conv2d(16, 4, 3, padding=1)
        self.post2 = nn.Conv2d(4, 1, 3, padding=1)

    def forward(self, f, corr_feature, pos=None):
        # Pad each backbone feature map, then crop the window centered on the
        # selected RoW position (pos) at the matching resolution.
        p0 = torch.nn.functional.pad(f[0], [16, 16, 16, 16])[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61]
        p1 = torch.nn.functional.pad(f[1], [8, 8, 8, 8])[:, :, 2*pos[0]:2*pos[0]+31, 2*pos[1]:2*pos[1]+31]
        p2 = torch.nn.functional.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0]+15, pos[1]:pos[1]+15]
        p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1)

        out = self.deconv(p3)
        out = self.post0(F.upsample(self.h2(out) + self.v2(p2), size=(31, 31)))
        out = self.post1(F.upsample(self.h1(out) + self.v1(p1), size=(61, 61)))
        out = self.post2(F.upsample(self.h0(out) + self.v0(p0), size=(127, 127)))
        out = out.view(-1, 127*127)
        return out
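A shape-level walkthrough of the forward pass. The feature sizes below are what a 255x255 search patch yields from the stride-8 ResNet-50 (they also follow from the pad/crop arithmetic above); all tensors are dummies:

f = (torch.randn(1, 64, 125, 125),   # stem output
     torch.randn(1, 256, 63, 63),    # layer1
     torch.randn(1, 512, 31, 31))    # layer2
corr_feature = torch.randn(1, 256, 25, 25)

refine = Refine()
mask = refine(f, corr_feature, pos=(12, 12))
print(mask.shape)  # torch.Size([1, 16129]) == flattened 127x127 mask logits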
Custom
The full network, inheriting from the SiamMask base class (whose methods are not shown here).
class Custom(SiamMask):
    def __init__(self, pretrain=False, **kwargs):
        super(Custom, self).__init__(**kwargs)
        self.features = ResDown(pretrain=pretrain)   # Siamese backbone + adjust
        self.rpn_model = UP(anchor_num=self.anchor_num, feature_in=256, feature_out=256)
        self.mask_model = MaskCorr()                 # mask branch
        self.refine_model = Refine()                 # mask refinement
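Wiring the shown modules together by hand gives a feel for the data flow. This is only a sketch, again assuming the repo is importable; the real tracking loop lives in the SiamMask base-class methods omitted above:

backbone = ResDown()
rpn = UP(anchor_num=5)
mask_head = MaskCorr()

z_f = backbone(torch.randn(1, 3, 127, 127))                        # template features
all_feats, x_f = backbone.forward_all(torch.randn(1, 3, 255, 255)) # search features
cls, loc = rpn(z_f, x_f)     # (1, 10, 25, 25) scores, (1, 20, 25, 25) offsets
masks = mask_head(z_f, x_f)  # (1, 3969, 25, 25) flattened 63x63 masks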