caffe 进行手写数字训练

案例数据准备

下载
链接:https://pan.baidu.com/s/10CmpZUdEVmma4A0mziu9dw
提取码:dmjr
复制这段内容后打开百度网盘手机App,操作更方便哦
解压后放到data/mnist

进入C:\Windows\System32\WindowsPowerShell\v1.0
管理员运行PowerShell
PS F:\caffe-windows> examples\mnist\create_mnist.ps1
生成两个目录
在这里插入图片描述
之后将mnist拷贝到自己的工程目录备用

运行

#coding='utf-8'
import lmdb
import caffe
from caffe.proto import caffe_pb2
from caffe import layers as L
from caffe import params as P
from matplotlib import pyplot as plt
import numpy as np


		 
solver_file = './mnist/lenet_auto_solver.prototxt'
train_proto = "./mnist/lenet_auto_train.prototxt"
train_lmdb = "./mnist/mnist_train_lmdb"
test_proto = "./mnist/lenet_auto_test.prototxt"
test_lmdb = "./mnist/mnist_test_lmdb"

def lenet(lmdb,batch_size):
	n = caffe.NetSpec()
	n.data,n.label = L.Data(batch_size=batch_size,backend=P.Data.LMDB,source=lmdb,transform_param=dict(scale=1./255),ntop=2)

	n.conv1 = L.Convolution(n.data,num_output=20,kernel_size=5,weight_filler=dict(type='xavier'))
	n.pool1 = L.Pooling(n.conv1,pool=P.Pooling.MAX,kernel_size=2,stride=2)

	n.conv2 = L.Convolution(n.pool1,num_output=50,kernel_size=5,weight_filler=dict(type='xavier'))
	n.pool2 = L.Pooling(n.conv2,pool=P.Pooling.MAX,kernel_size=2,stride=2)

	n.fc1 = L.InnerProduct(n.pool2, num_output=500, weight_filler=dict(type='xavier'))
	n.relu1 = L.ReLU(n.fc1,in_place=True)

	n.score = L.InnerProduct(n.relu1, num_output=10, weight_filler=dict(type='xavier'))
	n.loss = L.SoftmaxWithLoss(n.score,n.label)

	return n.to_proto()

def gen_solver(solver_file, train_proto, test_net_file=None):
	s = caffe_pb2.SolverParameter()
	s.train_net = train_proto
	if not test_proto:
		s.test_net.append(train_proto)
	else:
		s.test_net.append(test_proto)
	s.test_interval = 500  
	s.test_iter.append(100) 
	s.display = 500
	s.max_iter = 10000	 
	s.base_lr = 0.001	   # 基础学习率
	s.momentum = 0.9		# momentum系数
	s.weight_decay = 5e-4	   # 正则化权值衰减因子,防止过拟合

	s.lr_policy = 'step'		# 学习率衰减方法
	s.stepsize=1000		 # 只对step方法有效, base_lr*gamma^floor(iter/stepsize)
	s.gamma = 0.1
	s.display = 500		 # 输出日志间隔迭代次数
	s.snapshot = 5000	  # 在指定迭代次数时保存模型
	s.snapshot_prefix = 'mnist/lenet'
	s.type = 'SGD'  # 迭代算法类型, ADADELTA, ADAM, ADAGRAD, RMSPROP, NESTEROV
	s.solver_mode = caffe_pb2.SolverParameter.GPU

	with open(solver_file, 'w') as f:
		f.write(str(s))




def write_data(train_proto,train_lmdb,test_proto,test_lmdb):
	with open(train_proto,'w') as f:
		f.write(str(lenet(train_lmdb,64)))

	with open(test_proto,'w') as f:
		f.write(str(lenet(test_lmdb,100)))


def main():
	write_data(train_proto,train_lmdb,test_proto,test_lmdb)
	#caffe.set_device(0)#使用GPU,我的caffe GPU配置失败,暂时使用CPU
	#caffe.set_mode_gpu()
	gen_solver(solver_file, train_proto, test_net_file=None)

	solver = None
	solver = caffe.SGDSolver(solver_file)

	#for k,v in solver.net.blobs.items():
	#	print(k,v.data.shape)

	test_interval = 20
	niter = 200

	train_loss = np.zeros(niter)
	test_acc = np.zeros(int(np.ceil(niter/test_interval)))

	output = np.zeros((niter,8,10))

	for it in range(niter):
		solver.step(1)
		train_loss[it] = solver.net.blobs['loss'].data
		solver.test_nets[0].forward(start='conv1')
		output[it]  = solver.test_nets[0].blobs["score"].data[:8]

		if it%test_interval == 0:
			print("run test ing...")
			correct = 0

			for test_it in range(100):
				solver.test_nets[0].forward()
				correct+=sum(solver.test_nets[0].blobs['score'].data.argmax(1) == solver.test_nets[0].blobs['label'].data)

			test_acc[it//test_interval] = correct/1e4

	_, ax1 = plt.subplots()
	ax2 = ax1.twinx()
	ax1.plot(np.arange(niter), train_loss)
	ax2.plot(test_interval*np.arange(len(test_acc)), test_acc, 'r')
	ax1.set_xlabel('iteration')
	ax1.set_ylabel('train loss')
	ax2.set_ylabel('test accuracy')
	ax2.set_title('Test accuracy:{:.2f}'.format(test_acc[-1]))
	_.savefig('result1.png')

main()

生成训练测试网络以及solver文件

在这里插入图片描述

利用上图网络及solver文件进行训练

在这里插入图片描述

生成结果图片

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

佐倉

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值