Learning to Refine Object Segments

40 篇文章 0 订阅
29 篇文章 0 订阅

在这里插入图片描述
源代码里面的版本
在这里插入图片描述

下面看一下源代码（SharpMask 的 PyTorch 实现）：

class SharpMask(nn.Module):
    """SharpMask: a pretrained DeepMask extended with a top-down refinement pathway.

    The trunk, crop trick, score branch and mask branch are reused from a
    DeepMask model restored from ``exps/deepmask/train/model_best.pth.tar``.
    On top of it the constructor builds one horizontal net and one vertical
    net per entry of ``self.skpos`` and chains them into ``self.refs``.
    The horizontal nets are presumably applied to trunk feature maps and the
    vertical nets to the top-down stream; the forward pass is not shown here.

    Args:
        config: configuration object. ``config.km`` (base channel count of
            the vertical nets) and ``config.ks`` (base channel count of the
            horizontal nets) are read here, and the object is forwarded to
            ``DeepMask``.
        context: if True, the horizontal nets for skip levels 1-3 first crop
            their input with a negative ``ZeroPad2d``.

    Raises:
        AssertionError: if the DeepMask checkpoint file does not exist.

    NOTE(review): ``self.refinement`` is called in ``createTopDownRefinement``
    but is not defined in this listing; it must be provided by the rest of
    the class.
    """

    def __init__(self, config=default_config, context=True):
        super().__init__()
        self.context = context  # when True, horizontal nets crop their input
        self.km, self.ks = config.km, config.ks
        # Skip-connection positions (presumably trunk layer indices, deepest
        # first -- TODO confirm against the forward pass).
        self.skpos = [6, 5, 4, 2]

        deepmask = DeepMask(config)
        deepmask_resume = os.path.join('exps', 'deepmask', 'train', 'model_best.pth.tar')
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so callers see the same exception type.
        if not os.path.exists(deepmask_resume):
            raise AssertionError("Please train DeepMask first")
        deepmask = load_pretrain(deepmask, deepmask_resume)

        # Reuse the pretrained DeepMask components.
        self.trunk = deepmask.trunk
        self.crop_trick = deepmask.crop_trick
        self.scoreBranch = deepmask.scoreBranch
        self.maskBranchDM = deepmask.maskBranch
        self.fSz = deepmask.fSz  # must be set before createVertical() reads it

        # Also sets self.neths and self.netvs as a side effect.
        self.refs = self.createTopDownRefinement()

        nph = self._count_params_m(self.neths)
        npv = self._count_params_m(self.netvs)
        print('| number of paramaters net h: {:.3f} M'.format(nph))
        print('| number of paramaters net v: {:.3f} M'.format(npv))
        print('| number of paramaters total: {:.3f} M'.format(nph + npv))

    @staticmethod
    def _count_params_m(modules):
        """Return the total parameter count of an iterable of modules, in millions."""
        return sum(p.numel() for m in modules for p in m.parameters()) / 1e+06

    @staticmethod
    def _padded_conv3x3(in_ch, out_ch):
        """Return ``[ReflectionPad2d(1), Conv2d(in_ch, out_ch, 3)]`` (keeps spatial size)."""
        return [nn.ReflectionPad2d(1), nn.Conv2d(in_ch, out_ch, 3)]

    def createTopDownRefinement(self):
        """Build the horizontal and vertical nets and chain them into refinement modules.

        Side effects: sets ``self.neths`` and ``self.netvs``.

        Returns:
            nn.ModuleList: ``[netvs[0], ref_0, ..., ref_{n-1}, final]`` with
            ``n == len(self.skpos)``, where ``ref_i = self.refinement(neths[i],
            netvs[i + 1])`` and ``final`` is a padded 3x3 conv to 1 channel.
        """
        self.neths = self.createHorizontal()
        self.netvs = self.createVertical()

        refs = nn.ModuleList([self.netvs[0]])
        for neth, netv in zip(self.neths, self.netvs[1:]):
            refs.append(self.refinement(neth, netv))

        # Each vertical net halves the channel count, so after len(self.skpos)
        # stages the stream carries km // 2**len(self.skpos) channels.
        refs.append(nn.Sequential(*self._padded_conv3x3(self.km // 2 ** len(self.skpos), 1)))

        return refs

    def createHorizontal(self):
        """Build one horizontal net per entry of ``self.skpos``.

        Net ``i`` maps ``nhu1`` input channels to ``(ks // 2**i) // 2`` channels
        through three padded 3x3 convs, with a ReLU after the first two. When
        ``self.context`` is True, levels 1, 2 and 3 first crop 2, 4 and 8 pixels
        from every border via a negative ``ZeroPad2d``.

        Returns:
            nn.ModuleList of nn.Sequential, in ``self.skpos`` order.
        """
        # Per level: (input channels, hidden channels, crop used when context=True).
        # Must have an entry for every skip position.
        specs = (
            (1024, 64, 0),
            (512, 64, -2),
            (256, 64, -4),
            (64, 64, -8),
        )
        neths = nn.ModuleList()
        for i in range(len(self.skpos)):
            nhu1, nhu2, ctx_crop = specs[i]
            crop = ctx_crop if self.context else 0
            nInps = self.ks // 2 ** i  # e.g. 32, 16, 8, 4 when ks == 32

            # Layer order is kept identical to the original so state_dict keys match.
            layers = [nn.ZeroPad2d(crop)] if crop != 0 else []
            layers += self._padded_conv3x3(nhu1, nhu2)
            layers.append(nn.ReLU(inplace=True))
            layers += self._padded_conv3x3(nhu2, nInps)
            layers.append(nn.ReLU(inplace=True))
            layers += self._padded_conv3x3(nInps, nInps // 2)

            neths.append(nn.Sequential(*layers))
        return neths

    def createVertical(self):
        """Build the vertical nets.

        ``netvs[0]`` is ``ConvTranspose2d(512, km, kernel_size=self.fSz)``. It
        presumably expands DeepMask's 512-channel feature map; TODO confirm.
        ``netvs[i + 1]`` maps ``km // 2**i`` channels to half that through two
        padded 3x3 convs with a ReLU in between.

        Returns:
            nn.ModuleList of length ``len(self.skpos) + 1``.
        """
        netvs = nn.ModuleList([nn.ConvTranspose2d(512, self.km, self.fSz)])

        for i in range(len(self.skpos)):
            nInps = self.km // 2 ** i  # e.g. 32, 16, 8, 4 when km == 32
            layers = self._padded_conv3x3(nInps, nInps)
            layers.append(nn.ReLU(inplace=True))
            layers += self._padded_conv3x3(nInps, nInps // 2)
            netvs.append(nn.Sequential(*layers))

        return netvs
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值