训练网络时往往需要迭代数万次,耗时数天;如果中途发生意外(例如断电)导致训练中断,岂不是要从头开始训练?
其实可以借助snapshot机制来解决:例如每隔1万次迭代保存一次网络参数,下次开始训练时先判断是否存在snapshot,如果存在,就从最新的snapshot恢复参数。
下面是一个例子
#coding:utf-8
import caffe
import numpy as np
import os
from caffe.proto import caffe_pb2
import google.protobuf as pb2
# Iteration number of the snapshot that training was resumed from (0 when no
# snapshot was loaded). Assigned in SolverWrapper.__init__ and added to
# solver.iter in SolverWrapper.snapshot when building the snapshot file name.
maxiteration=0
class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.

    This wrapper gives us control over the snapshotting process: on
    construction it resumes from the newest ``*_iter_<N>.caffemodel`` file
    found in ``output_dir`` (falling back to ``pretrained_model`` when there
    is none), and ``snapshot()`` writes new snapshots whose iteration number
    continues from the resumed one.
    """

    def __init__(self, solver_prototxt, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper.

        Args:
            solver_prototxt: Path to the text-format solver definition.
            output_dir: Directory that snapshots are written to; it is also
                searched for an existing snapshot to resume from.
            pretrained_model: Optional weights file loaded only when no
                snapshot is found in ``output_dir``.
        """
        global maxiteration
        self.output_dir = output_dir
        self.solver = caffe.SGDSolver(solver_prototxt)

        # Look for the most recent snapshot in the output directory.
        latest = self._find_latest_snapshot(output_dir)
        if latest is not None:
            resume_iter, snapshot_path = latest
            # Output text is kept identical to the original print statements.
            print('resume training from iteration  {:d}'.format(resume_iter))
            # Remember the offset so snapshot() keeps numbering from here.
            maxiteration = resume_iter
            print('maxiteration: {:d}'.format(maxiteration))
            # Copy the saved parameters into the network.
            # NOTE(review): copy_from presumably restores weights only, not
            # the solver's iteration counter/history -- confirm against the
            # pycaffe docs; the offset above only affects snapshot names.
            self.solver.net.copy_from(snapshot_path)
        elif pretrained_model is not None:
            # No snapshot available: initialize from the pretrained model
            # (e.g. one trained on ImageNet).
            print('Loading pretrained model '
                  'weights from {:s}'.format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

    @staticmethod
    def _find_latest_snapshot(snapshot_dir):
        """Return ``(iteration, path)`` of the newest snapshot, or ``None``.

        Only files named ``*_iter_<N>.caffemodel`` are considered, so solver
        state files and unrelated files in the directory are ignored instead
        of crashing the integer parse or being loaded as weights. A missing
        directory is treated the same as an empty one.
        """
        if not os.path.isdir(snapshot_dir):
            return None
        latest = None
        for name in os.listdir(snapshot_dir):
            stem, ext = os.path.splitext(name)
            if ext != '.caffemodel':
                continue
            _, marker, iter_text = stem.rpartition('_iter_')
            if not marker or not iter_text.isdigit():
                continue
            iteration = int(iter_text)
            if latest is None or iteration > latest[0]:
                latest = (iteration, os.path.join(snapshot_dir, name))
        return latest

    # ...
    # (code elided in the original excerpt)
    # ...

    def snapshot(self):
        """Save the current network weights to ``output_dir``.

        The file is named ``<snapshot_prefix><infix>_iter_<N>.caffemodel``,
        where ``N`` is ``solver.iter`` plus the module-level ``maxiteration``
        offset set when training was resumed from an earlier snapshot.

        NOTE(review): ``cfg`` is not imported in the visible part of this
        file; it is presumably the project's config object -- confirm it is
        in scope. No bounding-box weight un-normalization happens in this
        excerpt.
        """
        net = self.solver.net
        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        # Build the snapshot file name.
        filename = (self.solver_param.snapshot_prefix + infix +
                    '_iter_{:d}'.format(self.solver.iter + maxiteration) +
                    '.caffemodel')
        # Prepend the output directory.
        filename = os.path.join(self.output_dir, filename)
        net.save(str(filename))
        print('Wrote snapshot to: {:s}'.format(filename))
训练时可以每隔1万次迭代调用一次snapshot函数。