源代码中的版本（SharpMask 类的实现）
下面看一下源代码：
class SharpMask(nn.Module):
    """SharpMask: a pretrained DeepMask plus a top-down refinement path.

    The trunk, crop trick, score branch and mask branch are taken from a
    DeepMask model loaded from a checkpoint; only the refinement modules
    (horizontal nets ``self.neths`` and vertical nets ``self.netvs``) are
    created here.

    Args:
        config: configuration object. ``config.km`` and ``config.ks`` are read
            here, and the whole object is passed to ``DeepMask``.
        context: if True, the horizontal branches crop their input with a
            negative zero-padding (see ``_HORIZONTAL_SPECS``); if False no
            cropping is applied.
        deepmask_resume: optional path to the pretrained DeepMask checkpoint.
            Defaults to ``exps/deepmask/train/model_best.pth.tar`` (relative
            to the current working directory).

    Raises:
        FileNotFoundError: if the DeepMask checkpoint does not exist.
    """

    # Per refinement level i: (in_channels of the first conv, hidden
    # channels, crop used when context=True). nn.ZeroPad2d with a negative
    # amount removes that many pixels from every border, i.e. it crops.
    _HORIZONTAL_SPECS = (
        (1024, 64, 0),
        (512, 64, -2),
        (256, 64, -4),
        (64, 64, -8),
    )

    def __init__(self, config=default_config, context=True, deepmask_resume=None):
        super(SharpMask, self).__init__()
        self.context = context  # enables cropping in the horizontal branches
        self.km, self.ks = config.km, config.ks
        # Trunk positions feeding the horizontal branches; one refinement
        # level per entry (presumably consumed by forward() - TODO confirm).
        self.skpos = [6, 5, 4, 2]

        if deepmask_resume is None:
            deepmask_resume = os.path.join('exps', 'deepmask', 'train', 'model_best.pth.tar')
        # Explicit check instead of `assert` so it still runs under `python -O`;
        # done before building DeepMask so a missing checkpoint fails fast.
        if not os.path.exists(deepmask_resume):
            raise FileNotFoundError(
                "Please train DeepMask first (checkpoint not found: {})".format(deepmask_resume))
        deepmask = DeepMask(config)
        deepmask = load_pretrain(deepmask, deepmask_resume)

        # Reuse the pretrained DeepMask components.
        self.trunk = deepmask.trunk
        self.crop_trick = deepmask.crop_trick
        self.scoreBranch = deepmask.scoreBranch
        self.maskBranchDM = deepmask.maskBranch
        self.fSz = deepmask.fSz

        # Also sets self.neths and self.netvs as a side effect.
        self.refs = self.createTopDownRefinement()

        nph = sum(p.numel() for h in self.neths for p in h.parameters()) / 1e+06
        npv = sum(p.numel() for h in self.netvs for p in h.parameters()) / 1e+06
        print('| number of paramaters net h: {:.3f} M'.format(nph))
        print('| number of paramaters net v: {:.3f} M'.format(npv))
        print('| number of paramaters total: {:.3f} M'.format(nph + npv))

    def createTopDownRefinement(self):
        """Build the refinement modules and return them in application order.

        Side effects: sets ``self.neths`` (horizontal nets) and ``self.netvs``
        (vertical nets).

        Returns:
            nn.ModuleList: ``netvs[0]``, then ``self.refinement(neths[i],
            netvs[i + 1])`` for every skip position, then a final reflection
            padded 3x3 conv producing a single output channel.
        """
        self.neths = self.createHorizontal()
        self.netvs = self.createVertical()

        refs = nn.ModuleList([self.netvs[0]])
        for neth, netv in zip(self.neths, self.netvs[1:]):
            refs.append(self.refinement(neth, netv))

        # Each vertical net halves the channel count (see createVertical), so
        # the last one outputs km // 2**len(skpos) channels; this assumes
        # self.refinement preserves that count - TODO confirm.
        refs.append(nn.Sequential(nn.ReflectionPad2d(1),
                                  nn.Conv2d(self.km // 2 ** len(self.skpos), 1, 3)))
        return refs

    def createHorizontal(self):
        """Create one horizontal (skip) branch per entry of ``self.skpos``.

        Branch i is: an optional crop (only when ``self.context`` and the
        level's crop is non-zero), then three reflection-padded 3x3 convs
        ``nhu1 -> nhu2 -> ks // 2**i -> (ks // 2**i) // 2`` with a ReLU after
        the first two.

        Returns:
            nn.ModuleList of nn.Sequential, one per skip position.

        Raises:
            ValueError: if ``self.skpos`` has more entries than there are
                configured levels in ``_HORIZONTAL_SPECS``.
        """
        if len(self.skpos) > len(self._HORIZONTAL_SPECS):
            raise ValueError(
                "createHorizontal supports at most {} skip positions, got {}".format(
                    len(self._HORIZONTAL_SPECS), len(self.skpos)))

        neths = nn.ModuleList()
        for i, (nhu1, nhu2, crop) in enumerate(self._HORIZONTAL_SPECS[:len(self.skpos)]):
            if not self.context:
                crop = 0
            nInps = self.ks // 2 ** i  # 32, 16, 8, 4 when ks == 32
            layers = []
            if crop != 0:
                layers.append(nn.ZeroPad2d(crop))
            # ReflectionPad2d(1) before each 3x3 conv keeps the spatial size.
            layers += [
                nn.ReflectionPad2d(1), nn.Conv2d(nhu1, nhu2, 3), nn.ReLU(inplace=True),
                nn.ReflectionPad2d(1), nn.Conv2d(nhu2, nInps, 3), nn.ReLU(inplace=True),
                nn.ReflectionPad2d(1), nn.Conv2d(nInps, nInps // 2, 3),
            ]
            neths.append(nn.Sequential(*layers))
        return neths

    def createVertical(self):
        """Create the vertical (top-down) nets.

        ``netvs[0]`` is ``ConvTranspose2d(512, km, fSz)``. For each skip
        position i, ``netvs[i + 1]`` maps ``km // 2**i`` channels to
        ``(km // 2**i) // 2`` channels with two reflection-padded 3x3 convs
        (ReLU in between).

        Returns:
            nn.ModuleList with ``len(self.skpos) + 1`` entries.
        """
        netvs = nn.ModuleList([nn.ConvTranspose2d(512, self.km, self.fSz)])
        for i in range(len(self.skpos)):
            nInps = self.km // 2 ** i  # 32, 16, 8, 4 when km == 32
            netvs.append(nn.Sequential(
                nn.ReflectionPad2d(1), nn.Conv2d(nInps, nInps, 3), nn.ReLU(inplace=True),
                nn.ReflectionPad2d(1), nn.Conv2d(nInps, nInps // 2, 3),
            ))
        return netvs