前言
前面我们已经训练好了模型这时候模型文件是这样的:
我这里演示的所以设置的迭代次数不多的。
测试自己的数据
转换均值文件(将二进制的均值文件转换为npy这里要使用的均值文件):
import numpy as np
import caffe
import sys
BinaryMeanPath = '~~~~~~~~~~/mean.binaryproto'
NpyMeanOuPath = '~~~~~~~~~~/meannpy.npy'
print 'Start.............'
blob = caffe.proto.caffe_pb2.BlobProto()
data = open( BinaryMeanPath , 'rb' ).read()
blob.ParseFromString(data)
arr = np.array( caffe.io.blobproto_to_array(blob) )
out = arr[0]
np.save( NpyMeanOuPath , out )
print 'Complete.............'
然后caffe-windows\models\bvlc_reference_caffenet\这个文件夹中取出deploy这个文件,修改最后的输出数量改为我们自己的分类数量
最后,这里测试自己的数据的代码是直接从官网上拿下来的:
# coding:utf-8
import numpy as np
MyCaffeRoot = '~~~~~~~~~~/mymnist/'
ImgTestPath = '~~~~~~~~~~/1/pic_hashiqi_Pos120.jpg' #测试图片路径
LabelsPath = MyCaffeRoot + 'labels.txt'
import sys
import caffe
import os
CaffeModelPath = MyCaffeRoot + 'caffenet_train_iter_4500.caffemodel'
DeployPath = MyCaffeRoot + 'deploy.prototxt'
NpyMeanPath = '~~~~~~~~~~/mean/meannpy.npy'
if os.path.exists(CaffeModelPath) == False:
print u'找不到模型的路径'
else:
print u'找到模型的路径......'
caffe.set_mode_cpu();
net = caffe.Net(DeployPath, CaffeModelPath, caffe.TEST) #创建网络
#负载均衡减去均值
mu = np.load(NpyMeanPath)
mu = mu.mean(1).mean(1)
print u'各个颜色通道的均值:', zip('BGR', mu)
transformer = caffe.io.Transformer({'data':net.blobs['data'].data.shape})
transformer.set_transpose('data',(2, 0, 1))
transformer.set_mean('data',mu)
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data',(2, 1, 0))
net.blobs['data'].reshape(50, 3, 227, 227)
#执行测试
out = net.forward()
# transform it and copy it into the net
image = caffe.io.load_image(ImgTestPath)
net.blobs['data'].data[...] = transformer.preprocess('data', image)
# perform classification
net.forward()
# obtain the output probabilities
output_prob = net.blobs['prob'].data[0]
#验证标签文件是否存在
if os.path.exists(LabelsPath) == False:
print u'标签文件不存在'
exit(0)
#读取标签文件
labels = np.loadtxt(LabelsPath, str, delimiter='\t')
sort = output_prob.argsort()[::-1][:2]
#labels[0] = '哈士奇'
print output_prob
print u'这个是--->:' , str(labels[sort[0]]).decode('utf-8')
print u'这个是--->:' , str(labels[sort[1]]).decode('utf-8')