caffe 学习笔记
生成caffe的solver文件需要调用 caffe_pb2 中的类 SolverParameter:
from caffe.proto import caffe_pb2
# 实例化配置类
s = caffe_pb2.SolverParameter()
# 网络
s.train_net = "train.prototxt"
s.test_net.append("test.prototxt")
# 测试间隔
s.test_interval = 100
s.test_iter.append(10)
# 最大迭代
s.max_iter = 1000
# 学习率
s.base_lr = 0.1
# 学习率衰减系数
s.weight_decay = 5e-4
# 学习率衰减策略
s.lr_policy = "step"
# 打印间隔
s.display = 10
# 保存参数间隔
s.snapshot = 10
# 保存位置
s.snapshot_prefix = "model"
# 优化策略
s.type = "SGD"
# GPU
s.solver_mode = caffe_pb2.SolverParameter.GPU
# 生成配置文件
with open("net/s.prototxt","w") as f:
f.write(str(s))