Resuming R-FCN/Faster R-CNN training from a snapshot

R-FCN and Faster R-CNN cannot resume training from a Caffe .solverstate snapshot out of the box; the training scripts only accept pretrained weights. Adding support requires small changes to lib/fast_rcnn/train.py and tools/train_net.py.

Specifically:
lib/fast_rcnn/train.py
Modify the __init__ function of SolverWrapper:

def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None,
                 ##########add#########
                 previous_state=None
                 ):
        """Initialize the SolverWrapper."""
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
            cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori
            assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

        if cfg.TRAIN.BBOX_REG:
            print 'Computing bounding-box regression targets...'
            self.bbox_means, self.bbox_stds = \
                    rdl_roidb.add_bbox_regression_targets(roidb)
            print 'done'

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print ('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model)
            self.solver.net.copy_from(pretrained_model)
            ##########add#########
        elif previous_state is not None:
            print ('Restoring State from {:s}').format(previous_state)
            self.solver.restore(previous_state)
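The difference between the two branches matters: copy_from loads network weights only (training restarts at iteration 0), while restore reloads the full solver state, including the iteration counter and SGD history. A minimal pure-Python sketch of that branching, where MockSolver is a hypothetical stand-in for caffe.SGDSolver, not the real Caffe API:

```python
# MockSolver is a hypothetical stand-in for caffe.SGDSolver, used only to
# illustrate the branching logic added to SolverWrapper.__init__ above.
class MockSolver(object):
    def __init__(self):
        self.iter = 0
        self.source = None

    def copy_from(self, weights_path):
        # Loads network weights only; iteration count stays at 0.
        self.source = weights_path

    def restore(self, state_path):
        # Loads weights *and* solver state (iteration, SGD history).
        self.source = state_path
        self.iter = 10000  # pretend the snapshot was taken at iteration 10000


def init_solver(pretrained_model=None, previous_state=None):
    solver = MockSolver()
    if pretrained_model is not None:
        solver.copy_from(pretrained_model)   # fresh training from weights
    elif previous_state is not None:
        solver.restore(previous_state)       # resume from a snapshot
    return solver


# Because of the elif, pretrained weights take precedence: pass only the
# snapshot when resuming, otherwise the solverstate is silently ignored.
resumed = init_solver(previous_state='model_iter_10000.solverstate')
print(resumed.iter)  # → 10000
```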

Modify the train_net function:

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000,
              ##########add#########
              previous_state=None
              ):
    """Train a Fast R-CNN network."""

    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model,
                       ##########add#########
                       previous_state=previous_state
                       )

    print 'Solving...'
    model_paths = sw.train_model(max_iters)
    print 'done solving'
    return model_paths

tools/train_net.py:
Modify the parse_args function by adding the following argument:

parser.add_argument('--snapshot', dest='previous_state',
                            help='initialize with previous state',
                            default=None, type=str) 
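As a quick sanity check of the new flag, the snippet below reconstructs just this one argument in a standalone parser (the real parser in tools/train_net.py defines many more options, and the snapshot path here is hypothetical):

```python
import argparse

# Standalone reconstruction of only the new --snapshot flag.
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--snapshot', dest='previous_state',
                    help='initialize with previous solver state',
                    default=None, type=str)

# dest='previous_state' means the value lands in args.previous_state.
args = parser.parse_args(['--snapshot',
                          'output/faster_rcnn/vgg16_iter_10000.solverstate'])
print(args.previous_state)
```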

Finally, modify the train_net call in the main function:

train_net(args.solver, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters,
              ##########add#########
              previous_state=args.previous_state
              )

With these changes you can resume training directly from a snapshot via the new --snapshot flag.
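A typical invocation then looks like the sketch below. The paths are hypothetical placeholders for your own setup; note that Caffe's restore() reads the .solverstate, which in turn references the .caffemodel written at the same snapshot, so keep the two files together, and set --iters to the total iteration count, not the remaining one, since the restored solver resumes from its saved iteration.

```shell
# Hypothetical paths -- substitute your own solver, imdb, and snapshot.
./tools/train_net.py --gpu 0 \
    --solver models/pascal_voc/VGG16/faster_rcnn_end2end/solver.prototxt \
    --snapshot output/faster_rcnn_end2end/voc_2007_trainval/vgg16_faster_rcnn_iter_10000.solverstate \
    --imdb voc_2007_trainval \
    --iters 70000 \
    --cfg experiments/cfgs/faster_rcnn_end2end.yml
```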
