下面来看一下 DeepMask 的源代码：
class DeepMask(nn.Module):
    """DeepMask model: a shared ResNet-50 based trunk plus mask and score heads.

    Only module construction is defined in this block. ``config`` must
    provide the integer attributes ``iSz``, ``oSz`` and ``gSz`` read below.

    Args:
        config: Configuration object (defaults to the module-level
            ``default_config``).
        context: Flag stored on the instance as ``self.context``. Nothing in
            this block reads it; the original inline note says
            "without context".
    """

    def __init__(self, config=default_config, context=True):
        super(DeepMask, self).__init__()
        self.config = config
        # Original note: "without context". Not read anywhere in this block.
        self.context = context
        self.strides = 16
        # Ceiling division: per the original note this is the feature-map
        # size, ceil([h, w] / 16).
        self.fSz = -(-self.config.iSz // self.strides)
        self.trunk = self.creatTrunk()
        # Original note: this updates the padding strategy of the Conv2d
        # layers in the backbone; the stock network presumably just pads
        # with zeros, and it is switched here to SymmetricPad2d or
        # ReflectionPad2d. ``updatePadding`` is defined elsewhere.
        updatePadding(self.trunk, nn.ReflectionPad2d)
        # -16 // 16 == -1: a negative pad of one element per side, which
        # acts as a crop. Original note: "for training".
        self.crop_trick = nn.ZeroPad2d(-16//self.strides)
        self.maskBranch = self.createMaskBranch()
        self.scoreBranch = self.createScoreBranch()

        # Report parameter counts (in millions) for each sub-module.
        npt = sum(p.numel() for p in self.trunk.parameters()) / 1e+06
        npm = sum(p.numel() for p in self.maskBranch.parameters()) / 1e+06
        nps = sum(p.numel() for p in self.scoreBranch.parameters()) / 1e+06
        print('| number of paramaters trunk: {:.3f} M'.format(npt))
        print('| number of paramaters mask branch: {:.3f} M'.format(npm))
        print('| number of paramaters score branch: {:.3f} M'.format(nps))
        print('| number of paramaters total: {:.3f} M'.format(npt + nps + npm))

    def creatTrunk(self):
        """Build the feature trunk shared by the mask and score branches.

        Original note: the trunk output is already 512 x 1 x 1 and is shared
        by both branches, which differs slightly from the original network
        structure; see Fig. 3(d) of the SharpMask paper.

        Returns:
            nn.Sequential: the truncated ResNet-50 followed by a
            1x1 conv -> ReLU -> (fSz x fSz) conv head.
        """
        # Original note: the backbone is ResNet-50, not VGG.
        resnet50 = torchvision.models.resnet50(pretrained=True)
        # Keep every child module except the last three. For torchvision's
        # ResNet-50 this should leave the stem through layer3 (1024 output
        # channels), matching the 1024 input channels expected below.
        trunk1 = nn.Sequential(*list(resnet50.children())[:-3])
        trunk2 = nn.Sequential(
            # 1x1 conv: 1024 -> 128 channels, spatial size unchanged
            # (original note: 1024*10*10 -> 128*10*10).
            nn.Conv2d(1024, 128, 1),
            nn.ReLU(inplace=True),
            # Kernel covers the whole fSz x fSz map. Original note: for a
            # 160-pixel input, fSz = 160 // 16 = 10, so
            # 128*10*10 -> 512*1*1.
            nn.Conv2d(128, 512, self.fSz)
        )
        return nn.Sequential(trunk1, trunk2)

    def createMaskBranch(self):
        """Build the mask head on top of the 512-channel trunk output.

        Returns:
            nn.Sequential: 1x1 conv to ``oSz**2`` channels followed by
            ``Reshape(oSz)``; when ``gSz > oSz`` this is wrapped in an outer
            ``nn.Sequential`` with a bilinear upsample to ``gSz x gSz``.
        """
        maskBranch = nn.Sequential(
            # Kernel size 1 keeps the spatial dims (original note:
            # 512*1*1 -> (56*56)*1*1).
            nn.Conv2d(512, self.config.oSz**2, 1),
            # ``Reshape`` is defined elsewhere; per the original note it
            # turns the result into an oSz x oSz (56*56) output.
            Reshape(self.config.oSz),
        )
        if self.config.gSz > self.config.oSz:
            # Upsample only when the target size exceeds the predicted size.
            upSample = nn.UpsamplingBilinear2d(size=[self.config.gSz, self.config.gSz])
            maskBranch = nn.Sequential(maskBranch, upSample)
        return maskBranch

    def createScoreBranch(self):
        """Build the score head on top of the 512-channel trunk output.

        Returns:
            nn.Sequential: dropout -> 1x1 conv (512 -> 1024) -> threshold ->
            dropout -> 1x1 conv (1024 -> 1).
        """
        scoreBranch = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(512, 1024, 1),
            # Values not greater than 0 are replaced by 1e-6. Original
            # note: "do not know why".
            nn.Threshold(0, 1e-6),
            nn.Dropout(0.5),
            nn.Conv2d(1024, 1, 1),
        )
        return scoreBranch