用python编写的caffe网络工具:包括网络训练、微调及中断后继续训练功能

-- coding: utf-8 --

“””
Created on Fri Mar 24 17:15:01 2017

@author: 爱吃菠菜的大力士
“””
import os

class Net():

def __init__(self):


@staticmethod
def fileopt(filename,content):
    fp=open(filename,'w')
    fp.write(content)
    fp.close()

def train(self):
    #函数功能:网络的训练
    self.solver_proto = '/home/coco/caffe/examples/ZFnet/solver.prototxt'
    out=r' 2>&1 |tee '
    _command=self.caffepath+'/build/tools/caffe train '+'--solver='+self.solver_proto+' '+out+self.log_train
    DeepID.fileopt(self.shtrain,_command)
    os.system(_command)

def resume(self,num1):
    #函数功能:根据当前状态,在继续训练模型       
    self.solver_proto = '/home/coco/caffe/examples/ZFnet/solver.prototxt'
    out=r' 2>&1 |tee '
    _command=self.caffepath+'/build/tools/caffe train '+'--solver='+self.solver_proto+' '+'--snapshot=/home/coco/caffe/examples/ZFnet/snapshot_iter_' + str(num1) +'.solverstate'+' '+out+self.log_train
    DeepID.fileopt(self.shtrain,_command)
    os.system(_command)

def finetuning(self):
    #函数功能:在已有的模型基础上,用自己的数据进行网络的微调
    self.solver_proto = '/home/coco/caffe/examples/ZFnet/solver.prototxt'
    out=r' 2>&1 |tee '
    _command=self.caffepath+'/build/tools/caffe train '+'--solver='+self.solver_proto+' '+'--weights=/home/coco/caffe/examples/ZFnet/googleNet.caffemodel --gpu 0'+out+self.log_test
    DeepID.fileopt(self.shtest,_command)
    os.system(_command)

def renewSolver(self,b):
    #函数功能:根据需求更改solver文件参数,如改变最大迭代次数等        
    f = open('/home/coco/caffe/examples/ZFnet/solver.prototxt', 'r+')
    flist = f.readlines() 
    flist[8] = 'max_iter: '+ str(b) + '\n'
    open('/home/coco/caffe/examples/ZFnet/solver.prototxt','w').writelines(flist)
    f.close()  

def renewTrain(self,mask):
    #函数功能:修改配置文件,如训练数据的地址及中值文件
    if mask == 1:
        newTrain = '    source: "/home/coco/caffe/examples/ZFnet/lmdbDATA/train_4_lmdb_112_224"'
        newVal = '    source: "/home/coco/caffe/examples/ZFnet/lmdbDATA/val_4_lmdb_112_224"'
    elif mask == 2:
        newTrain = '    source: "/home/coco/caffe/examples/ZFnet/lmdbDATA/train_4_lmdb_224_112"'
        newVal = '    source: "/home/coco/caffe/examples/ZFnet/lmdbDATA/val_4_lmdb_224_112"'
    else:
        newTrain = '    source: "/home/coco/caffe/examples/ZFnet/lmdbDATA/train_4_lmdb_224_224"'
        newVal = '    source: "/home/coco/caffe/examples/ZFnet/lmdbDATA/val_4_lmdb_224_224"'


    f = open('/home/coco/caffe/examples/ZFnet/train_val.prototxt', 'r+')
    flist = f.readlines() 
    flist[10] = newTrain + '\n'
    flist[24] = newVal + '\n'
    open('/home/coco/caffe/examples/ZFnet/train_val.prototxt','w').writelines(flist)
    f.close()              

def demo()
#Net.train()
#Net.finetuning()
#Net.resume()
#代码块功能:用不同的数据交叉训练网络,如采用不同尺度的数据 224*224,112*224,224*112等
#以使得网络适应不同的尺寸
#Net.renewSolver(200)
#Net.renewTrain(1)
#Net.train()
#Net.renewSolver(400)
#Net.renewTrain(2)
#Net.resume(200)
#Net.renewSolver(600)
#Net.renewTrain(3)
#Net.resume(400)
#for i in range(50):
#max_iter = 600 + 600 * i
#Net.renewSolver(max_iter + 200)
#Net.renewTrain(1)
#Net.resume(max_iter)
#Net.renewSolver(max_iter + 400)
#Net.renewTrain(2)
#Net.resume(max_iter + 200)
#Net.renewSolver(max_iter + 600)
#Net.renewTrain(3)
#Net.resume(max_iter + 400)

if name==’main‘:

demo()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值