下面给出调用已训练好的 Caffe 模型进行测试的示例代码,
包括:模型调用、耗时统计,以及将预测结果写入 txt 文件。
import os
import caffe
import numpy as np
import time
# Locations of the Caffe checkout, the deploy definition, the trained weights
# and the test data used by the rest of this script.
root = '/home/xuqiong/makeall/caffe/'
deploy = os.path.join(root, 'examples/xq0523pm/shufflenet_deploy.prototxt')
caffe_model = os.path.join(
    root, 'examples/xq0523pm/shufflenet_train_iter_50000.caffemodel'
)
# NOTE: no mean file is configured; its assignment stays commented out:
# mean_file = '/mnt/data2/xuqiong/data/split/mean.binaryprototxt
test = '/mnt/data2/xuqiong/data/split/test/'

# Run Caffe on the CPU.
caffe.set_mode_cpu()
# Collect the full path of every regular file in the pristine_images folder.
# The listing is sorted because os.listdir() returns entries in arbitrary
# order, and sub-directories are skipped so that only files reach the image
# loader further down.
image_dir = test + 'pristine_images/'  # renamed from `dir`, which shadowed the builtin
filelist = sorted(
    os.path.join(image_dir, fn)
    for fn in os.listdir(image_dir)
    if os.path.isfile(os.path.join(image_dir, fn))
)
# Number of images whose predicted class matched the expected class.
count = 0

# Per-image durations in seconds, overwritten on every loop iteration:
# timei = image load + preprocessing, timef = forward pass,
# timem = reading the output blob and picking the class, timea = their sum.
timei = timef = timem = timea = 0

# Running totals of the four durations above across all images.
timei_all = timef_all = timem_all = timea_all = 0
# Classify every image in `filelist`, append "<path> <class>" to label.txt,
# accumulate per-stage timings, then print the hit count and average timings.


def _expected_label(image_path):
    """Return the class expected for an image, derived from its parent folder.

    The folder 'pristine_images' maps to class 5. A folder whose name ends in
    the character '1'..'5' maps to class 4..0 respectively. Any other folder
    yields None, which never equals a predicted class.
    """
    folder = os.path.basename(os.path.dirname(image_path))
    if folder == 'pristine_images':
        return 5
    # The suffix is a *character*; comparing it with an int (as the previous
    # version did) is always False, so the keys here are strings.
    return {'1': 4, '2': 3, '3': 2, '4': 1, '5': 0}.get(folder[-1:])


# The network and the preprocessor do not depend on the image, so they are
# built once here rather than once per file. They were created before the
# first timestamp previously as well, so the measured stages are unchanged.
caffe.set_mode_cpu()
net = caffe.Net(deploy, caffe_model, caffe.TEST)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))  # (H, W, C) -> (C, H, W)
# NOTE: mean subtraction is not applied (no mean is set on the transformer).
transformer.set_raw_scale('data', 255)  # scale pixel values by 255
transformer.set_channel_swap('data', (2, 1, 0))  # reverse channel order (RGB -> BGR)

# Open the label file once; the with-block closes it even if an image fails.
with open("/home/xuqiong/makeall/caffe/examples/xq0523pm/label.txt", "a+") as f:
    for img in filelist:
        time0 = time.time()
        im = caffe.io.load_image(img)
        net.blobs['data'].data[...] = transformer.preprocess('data', im)
        time1 = time.time()
        net.forward()
        time2 = time.time()
        prob = net.blobs['fcout6'].data[0].flatten()  # scores of the output layer
        # Index of the highest score; unlike argsort()[5] this does not
        # assume the layer has exactly six outputs.
        order = int(prob.argmax())
        time3 = time.time()

        timei = time1 - time0  # load + preprocess
        timef = time2 - time1  # forward pass
        timem = time3 - time2  # read output + pick class
        timea = timei + timef + timem
        timei_all += timei
        timef_all += timef
        timem_all += timem
        timea_all += timea

        print("timei: ", timei*1000, "ms")
        print("timef: ", timef*1000, "ms")
        print("timem: ", timem*1000, "ms")
        print("timea: ", timea*1000, "ms")
        print("the class is: ", order)
        f.write(img + ' ' + str(order) + '\n')

        # Count the prediction as correct when it matches the folder's class.
        if order == _expected_label(img):
            count += 1

total = len(filelist)
print("shufflenetv1, cpu")
print("ok count: ", count)
print("all count: ", total)
if total:
    # Apply the %.2f format instead of printing it literally.
    print("timei average: %.2f ms" % (timei_all * 1000 / total))
    print("timef average: %.2f ms" % (timef_all * 1000 / total))
    print("timem average: %.2f ms" % (timem_all * 1000 / total))
    print("timea average: %.2f ms" % (timea_all * 1000 / total))
else:
    # Nothing was processed, so there is no average to report.
    print("no images found, timing averages skipped")