# -*- coding: utf-8 -*-
"""
Created on Fri Mar 24 17:15:01 2017
@author: 爱吃菠菜的大力士
"""
import os
class Net(object):
    """Thin driver around the Caffe command-line tool for the ZFnet example.

    Builds ``caffe train`` shell commands, saves each one to a shell script,
    runs it with ``os.system`` and tees the output into a log file. It can
    also patch ``max_iter`` in the solver prototxt and swap the LMDB sources
    in the train/val prototxt, so training can alternate between inputs of
    different sizes.
    """

    # Shell suffix that merges stderr into stdout and copies both to a log.
    _TEE = ' 2>&1 |tee '

    # renewTrain mask -> size suffix of the LMDB directories; any other mask
    # selects '224_224'.
    _LMDB_SIZES = {1: '112_224', 2: '224_112'}

    def __init__(self, caffepath='/home/coco/caffe', workdir=None,
                 log_train=None, log_test=None, shtrain=None, shtest=None):
        """Configure the paths used by every command.

        Args:
            caffepath: root of the Caffe checkout (contains build/tools/caffe).
            workdir: directory holding solver.prototxt, train_val.prototxt,
                snapshots, the pretrained weights and lmdbDATA/. Defaults to
                ``caffepath + '/examples/ZFnet'``.
            log_train, log_test: log files that ``tee`` writes to.
            shtrain, shtest: shell scripts the generated commands are saved to.
        """
        # NOTE(review): the original __init__ had no body, so these attributes
        # were never set. The default log/script file names below are
        # placeholders -- TODO confirm against the real setup.
        if workdir is None:
            workdir = caffepath + '/examples/ZFnet'
        self.caffepath = caffepath
        self.workdir = workdir
        self.solver_proto = workdir + '/solver.prototxt'
        self.train_val_proto = workdir + '/train_val.prototxt'
        # Presumably matches snapshot_prefix in solver.prototxt; resume()
        # loads '<snapshot_prefix>_iter_<N>.solverstate'.
        self.snapshot_prefix = workdir + '/snapshot'
        self.weights = workdir + '/googleNet.caffemodel'
        self.log_train = (log_train if log_train is not None
                          else workdir + '/train.log')
        self.log_test = (log_test if log_test is not None
                         else workdir + '/test.log')
        self.shtrain = (shtrain if shtrain is not None
                        else workdir + '/train.sh')
        self.shtest = (shtest if shtest is not None
                       else workdir + '/test.sh')

    @staticmethod
    def fileopt(filename, content):
        """Write *content* to *filename*, overwriting any existing file."""
        with open(filename, 'w') as fp:
            fp.write(content)

    @staticmethod
    def _proto_key(line):
        """Return the field name of a prototxt line ('max_iter: 10' -> 'max_iter')."""
        return line.split(':', 1)[0].strip()

    def _build_command(self, options, log):
        """Return the ``caffe train`` shell command with output tee'd to *log*."""
        return (self.caffepath + '/build/tools/caffe train ' + options
                + self._TEE + log)

    def _run(self, options, log, script):
        """Save the command to *script*, execute it, return os.system's status."""
        command = self._build_command(options, log)
        self.fileopt(script, command)
        # The pipe into tee needs a shell, hence os.system rather than a
        # list-based subprocess call; all paths come from local configuration.
        return os.system(command)

    def train(self):
        """Train the network from scratch with the solver in self.solver_proto.

        Returns:
            The exit status reported by ``os.system``.
        """
        return self._run('--solver=' + self.solver_proto + ' ',
                         self.log_train, self.shtrain)

    def resume(self, num1):
        """Continue training from the snapshot taken at iteration *num1*.

        Returns:
            The exit status reported by ``os.system``.
        """
        snapshot = self.snapshot_prefix + '_iter_' + str(num1) + '.solverstate'
        return self._run('--solver=' + self.solver_proto + ' '
                         + '--snapshot=' + snapshot + ' ',
                         self.log_train, self.shtrain)

    def finetuning(self):
        """Fine-tune on GPU 0 starting from the pretrained weights in self.weights.

        Returns:
            The exit status reported by ``os.system``.
        """
        # NOTE(review): as in the original, this writes to log_test/shtest
        # even though it is a training run -- confirm that is intended.
        return self._run('--solver=' + self.solver_proto + ' '
                         + '--weights=' + self.weights + ' --gpu 0',
                         self.log_test, self.shtest)

    def renewSolver(self, b):
        """Set ``max_iter: b`` in the solver prototxt.

        The existing ``max_iter`` line is located by field name rather than
        by a fixed line number. If none exists, one is appended.
        """
        with open(self.solver_proto, 'r') as f:
            lines = f.readlines()
        new_line = 'max_iter: ' + str(b) + '\n'
        for i, line in enumerate(lines):
            if self._proto_key(line) == 'max_iter':
                lines[i] = new_line
                break
        else:
            # Keep the appended field on its own line.
            if lines and not lines[-1].endswith('\n'):
                lines[-1] += '\n'
            lines.append(new_line)
        with open(self.solver_proto, 'w') as f:
            f.writelines(lines)

    def renewTrain(self, mask):
        """Point the train/val prototxt at the LMDBs for one input size.

        Args:
            mask: 1 -> *_112_224, 2 -> *_224_112, anything else -> *_224_224.

        The first ``source:`` line is set to the train LMDB and the second to
        the val LMDB; each line keeps its original indentation.

        Raises:
            ValueError: if the prototxt has fewer than two ``source:`` lines.
        """
        size = self._LMDB_SIZES.get(mask, '224_224')
        lmdb_dir = self.workdir + '/lmdbDATA/'
        new_sources = (
            'source: "' + lmdb_dir + 'train_4_lmdb_' + size + '"',
            'source: "' + lmdb_dir + 'val_4_lmdb_' + size + '"',
        )
        with open(self.train_val_proto, 'r') as f:
            lines = f.readlines()
        positions = [i for i, line in enumerate(lines)
                     if self._proto_key(line) == 'source']
        if len(positions) < 2:
            raise ValueError('expected at least two "source:" lines in '
                             + self.train_val_proto)
        for i, new_source in zip(positions, new_sources):
            old = lines[i]
            indent = old[:len(old) - len(old.lstrip())]
            lines[i] = indent + new_source + '\n'
        with open(self.train_val_proto, 'w') as f:
            f.writelines(lines)
def demo(net=None, rounds=50, step=200):
    """Cross-train the network on data of several input scales.

    Alternates the training data between sizes 112*224, 224*112 and 224*224
    so the network adapts to different input shapes. Each stage raises
    max_iter by *step*, switches the LMDB sources to the next scale and
    resumes from the previous stage's iteration. The first stage trains
    from scratch.

    Args:
        net: a Net instance (or any object exposing renewSolver, renewTrain,
            train and resume). A default ``Net()`` is created when None.
        rounds: number of extra passes over the three scales after the first
            pass. The defaults reproduce the original schedule: an initial
            pass (max_iter 200/400/600) followed by 50 rounds.
        step: number of iterations per stage.
    """
    if net is None:
        net = Net()
    scales = (1, 2, 3)  # renewTrain masks: 112_224, 224_112, 224_224
    for stage in range((rounds + 1) * len(scales)):
        net.renewSolver(step * (stage + 1))
        net.renewTrain(scales[stage % len(scales)])
        if stage == 0:
            net.train()
        else:
            # Assumes the solver saved a snapshot at iteration step * stage,
            # i.e. at the previous stage's max_iter.
            net.resume(step * stage)
if __name__ == '__main__':
    # Run the cross-scale training schedule only when executed as a script,
    # not on import.
    demo()