caffe训练(6)
使用python生成solver.prototxt文件
solver.prototxt文件中各个参数的具体含义可见博客:caffe总结(十)solver.prototxt参数含义
以分析的cifar10_quick_solver.prototxt文件为例,使用python程序,生成这个文件。
1.代码如下:
# -*- coding: UTF-8 -*-
import caffe #导入caffe包
def write_sovler():
my_project_root = "D:/caffe-master/zzhld/" #my-caffe-project目录
sovler_string = caffe.proto.caffe_pb2.SolverParameter() #sovler存储
solver_file = my_project_root + 'solver.prototxt' #sovler文件保存位置
sovler_string.train_net = my_project_root + 'train.prototxt' #train.prototxt位置指定
sovler_string.test_net.append(my_project_root + 'test.prototxt') #test.prototxt位置指定
sovler_string.test_iter.append(100) #测试迭代次数
sovler_string.test_interval = 500 #每训练迭代test_interval次进行一次测试
sovler_string.base_lr = 0.001 #基础学习率
sovler_string.momentum = 0.9 #动量
sovler_string.weight_decay = 0.004 #权重衰减
sovler_string.lr_policy = 'fixed' #学习策略
sovler_string.display = 100 #每迭代display次显示结果
sovler_string.max_iter = 4000 #最大迭代数
sovler_string.snapshot = 4000 #保存临时模型的迭代数
sovler_string.snapshot_format = 0 #临时模型的保存格式,0代表HDF5,1代表BINARYPROTO
sovler_string.snapshot_prefix = 'examples/cifar10/cifar10_quick' #模型前缀
sovler_string.solver_mode = caffe.proto.caffe_pb2.SolverParameter.GPU #优化模式
with open(solver_file, 'w') as f:
f.write(str(sovler_string))
if __name__ == '__main__':
write_sovler()
-
特别注意的是:
-
上面代码首先需要更改路径,其余根据需要更改,不进行更改也可以运行出结果;
-
在编写路径时,我试验了几次必须要求斜杠是“/”,另外那个如果在windows中直接复制的话,会有问题。
-
2.运行结果:
训练模型
从第一篇笔记至此,我们已经了解到如何将jpg图片转换成Caffe使用的db(levelbd/lmdb)文件,如何计算数据均值,如何使用python生成solver.prototxt、train.prototxt、test.prototxt文件。接下来,就可以进行训练的最后一步,使用caffe提供的python接口训练生成模型。如果不进行可视化,只想得到一个最终的训练model,可以使用如下代码:
import caffe
my_project_root = "/home/Jack-Cui/caffe-master/my-caffe-project/" #my-caffe-project目录
solver_file = my_project_root + 'solver.prototxt' #sovler文件保存位置
caffe.set_device(0) #选择GPU-0
caffe.set_mode_gpu()
solver = caffe.SGDSolver(solver_file)
solver.solve()
现在,如何训练生成模型的简单步骤已经讲完。接下来,以mnist实例,整合所学内容,训练生成model,并使用生成的model进行预测。
原文链接:https://blog.csdn.net/c406495762/article/details/70306728