caffe利用snapshot从断点恢复训练

训练网络时迭代次数往往需要数万次,需要训练几天,如果突发什么意外(断电)训练停止了岂不要从头训练。

其实借用snapshot机制,比如每隔1万次迭代保存一下网络参数,然后下次训练判断有无snapshot,有的话从snapshot恢复参数就可以了。

下面是一个例子


#coding:utf-8
import caffe
import numpy as np
import os

from caffe.proto import caffe_pb2
import google.protobuf as pb2

maxiteration=0

class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.
    This wrapper gives us control over he snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper."""
        self.output_dir = output_dir

        self.solver = caffe.SGDSolver(solver_prototxt)
#判断输出路径是否有snapshot文件
        flist=os.listdir('/home/hj/py-R-FCN/output/rfcn_end2end_ohem/voc_0712_trainval')
        global maxiteration
        if len(flist)>0:
            maxit=0
            maxid=0
#得到最新的snapshot文件
            for i in range(len(flist)):
                ite=int(flist[i].split('_')[-1].split('.')[0])
                if ite>maxit:
                    maxit=ite
                    maxid=i
            print 'resume training from iteration ' ,maxit
            maxiteration=maxit
            print 'maxiteration:',maxiteration
#拷贝参数到网络
            self.solver.net.copy_from(os.path.join('/home/hj/py-R-FCN/output/rfcn_end2end_ohem/voc_0712_trainval',flist[maxid]))
#如果没有snapshot文件则用imagenet上训练得到的模型初始化参数
        elif pretrained_model is not None:
            print ('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model)
            self.solver.net.copy_from(pretrained_model)

        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)
        ...
        ...
        ...

#定义snapshot函数
    def snapshot(self):
        """Take a snapshot of the network after unnormalizing the learned
        bounding-box regression weights. This enables easy use at test-time.
        """
        net = self.solver.net


        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
#得到snapshot的文件名
        filename = (self.solver_param.snapshot_prefix + infix +
                    '_iter_{:d}'.format(self.solver.iter+maxiteration) + '.caffemodel')
#加上路径
        filename = os.path.join(self.output_dir, filename)
        net.save(str(filename))
        print 'Wrote snapshot to: {:s}'.format(filename)

snapshot函数在你训练时可以每隔1万次运行一次


caffe 在已有模型上继续训练



  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值