R-FCN / Faster R-CNN：使用 snapshot（.solverstate）继续训练
R-FCN 和 Faster R-CNN 自带的训练脚本不支持从 .solverstate 恢复训练。
需要修改 lib/fast_rcnn/train.py 和 tools/train_net.py 两个文件。
具体修改如下。

lib/fast_rcnn/train.py：
修改 SolverWrapper 类的 __init__ 函数：
def __init__(self, solver_prototxt, roidb, output_dir,
             pretrained_model=None,
             ##########add#########
             previous_state=None):
    """Initialize the SolverWrapper.

    Args:
        solver_prototxt: path to the Caffe solver definition, passed to
            caffe.SGDSolver.
        roidb: training roidb; when cfg.TRAIN.BBOX_REG is set it is passed
            to rdl_roidb.add_bbox_regression_targets.
        output_dir: stored on self.output_dir.
        pretrained_model: optional .caffemodel whose weights initialize the
            net (fresh training).
        previous_state: optional Caffe .solverstate to resume from. It takes
            precedence over pretrained_model, because solver.restore()
            brings back the weights, the iteration counter and the solver
            history together.
    """
    self.output_dir = output_dir

    if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
            cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
        # RPN can only use precomputed normalization because there are no
        # fixed statistics to compute a priori
        assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

    if cfg.TRAIN.BBOX_REG:
        print('Computing bounding-box regression targets...')
        self.bbox_means, self.bbox_stds = \
            rdl_roidb.add_bbox_regression_targets(roidb)
        print('done')

    self.solver = caffe.SGDSolver(solver_prototxt)
    ##########add#########
    # Check the snapshot FIRST. With the pretrained-model branch tested
    # first, a --snapshot is silently ignored whenever --weights is also
    # given (the stock experiment scripts typically pass --weights).
    # restore() reloads the net weights itself, so copying the pretrained
    # model on top would be redundant.
    # NOTE: Caffe reloads the .caffemodel path recorded inside the
    # .solverstate, so that file must still exist at that path.
    if previous_state is not None:
        print('Restoring State from {:s}'.format(previous_state))
        self.solver.restore(previous_state)
    elif pretrained_model is not None:
        print(('Loading pretrained model '
               'weights from {:s}').format(pretrained_model))
        self.solver.net.copy_from(pretrained_model)

    # ... rest of __init__ unchanged ...
在同一文件中，修改 train_net 函数：
def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000,
              ##########add#########
              previous_state=None):
    """Train a Fast R-CNN network.

    Args:
        solver_prototxt: path to the Caffe solver definition.
        roidb: training roidb; filtered with filter_roidb before use.
        output_dir: forwarded to SolverWrapper.
        pretrained_model: optional .caffemodel used to initialize the net.
        max_iters: forwarded to SolverWrapper.train_model.
        previous_state: optional .solverstate to resume training from.

    Returns:
        The value returned by sw.train_model(max_iters) (model_paths).
    """
    roidb = filter_roidb(roidb)
    # Forward BOTH options. Passing only previous_state=... (and dropping
    # pretrained_model) would make --weights a silent no-op for ordinary
    # training from an ImageNet model.
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model,
                       ##########add#########
                       previous_state=previous_state)

    print('Solving...')
    model_paths = sw.train_model(max_iters)
    print('done solving')
    return model_paths
tools/train_net.py：
在 parse_args 函数中添加如下参数：
# Optional Caffe .solverstate to resume from; parsed into
# args.previous_state (note the dest=), not args.snapshot.
parser.add_argument('--snapshot',
                    dest='previous_state',
                    type=str,
                    default=None,
                    help='initialize with previous state')
修改 main 函数中对 train_net 的调用：
# Hand the parsed CLI options to train_net; the --snapshot option is
# stored under dest='previous_state', hence args.previous_state.
train_net(
    args.solver, roidb, output_dir,
    pretrained_model=args.pretrained_model,
    max_iters=args.max_iters,
    ##########add#########
    previous_state=args.previous_state,
)
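这样就可以通过 --snapshot 参数从 .solverstate 继续训练了，例如：

    python tools/train_net.py --gpu 0 --solver <solver.prototxt> --imdb voc_2007_trainval --iters 70000 --cfg <cfg.yml> --snapshot <xxx_iter_10000.solverstate>

恢复训练时不要再传 --weights。

另外注意：.solverstate 文件由 Caffe 自带的快照机制生成。py-faster-rcnn 默认的 solver.prototxt 中设置了 snapshot: 0，只用自己的 snapshot() 保存 .caffemodel。因此需要先把 snapshot 设为大于 0 的间隔，训练过程中才会生成 .solverstate。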