使用训练好的模型进行测试,只需要准备好模型文件和测试数据即可。
系统: Ubuntu 14.04
MXNet: 0.904
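1. 模型和数据准备:需要训练得到的模型文件 vggnew-symbol.json 和 vggnew-0040.params(即 load_checkpoint 以前缀 'vggnew'、第 40 轮读取的两个文件),以及待测试的图片。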
2.模型加载测试
import mxnet as mx

# Restore the network definition and the weights saved at epoch 40 under
# the checkpoint prefix 'vggnew'.
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix='vggnew',
                                                       epoch=40)

# Wrap the symbol in a Module placed on the GPU; the input is named 'data'
# and the label 'softmax_label'.
mod = mx.mod.Module(
    symbol=sym,
    data_names=['data'],
    label_names=['softmax_label'],
    context=mx.gpu(),
)

# Inference only: allocate for one 3x224x224 input, no label shapes given.
mod.bind(data_shapes=[('data', (1, 3, 224, 224))], for_training=False)

# Copy the restored weights into the bound module.
mod.set_params(arg_params=arg_params, aux_params=aux_params)
如果需要输出类别名称,还要准备一个 synset.txt 文件:每行一个标签,第 i 行(从 0 开始计数)对应类别索引 i,格式如图:
# Load the class names: one label per line, so line i names class index i.
# Trailing whitespace (including the newline) is stripped from each line.
with open('synset.txt', 'r') as synset_file:
    labels = list(map(str.rstrip, synset_file))
对图片进行预处理并预测:
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import numpy as np
# Lightweight batch object with a single ``data`` field; predict() hands it
# to mod.forward() in place of a full data-batch object.
from collections import namedtuple

Batch = namedtuple('Batch', 'data')
def get_image(url, show=False, size=(224, 224)):
    """Load an image file and convert it into a one-image NCHW batch.

    Args:
        url: path of the image file (passed to cv2.imread).
        show: if True, display the RGB image with matplotlib.
        size: target size as (width, height), in cv2.resize order. The
            default matches the (1, 3, 224, 224) input shape the module
            above is bound with.

    Returns:
        A numpy array of shape (1, 3, height, width) in RGB channel order,
        or None if the file could not be read as an image.
    """
    bgr = cv2.imread(url)
    # cv2.imread reports failure by returning None rather than raising, so
    # check before cvtColor, which would otherwise raise on a None input.
    if bgr is None:
        return None
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if show:
        plt.imshow(img)
        plt.axis('off')
    img = cv2.resize(img, tuple(size))
    # (height, width, channel) -> (channel, height, width), then prepend
    # the batch axis: result is (batch, RGB, height, width).
    img = np.transpose(img, (2, 0, 1))
    return img[np.newaxis, :]
def predict(url, show=True):
    """Classify one image with ``mod`` and print the top-1 class.

    Prints the index of the highest-scoring class and its label from
    ``labels`` (read from synset.txt).

    Args:
        url: path of the image file.
        show: whether to display the image while loading it (default True,
            matching the previous behavior).

    Returns:
        The index of the highest-scoring class, or None if the file could
        not be read as an image.
    """
    img = get_image(url, show=show)
    if img is None:
        # get_image returns None for unreadable files; skip them instead of
        # letting mx.nd.array(None) fail.
        print('skip unreadable image: %s' % url)
        return None
    # Forward pass; the first output presumably holds per-class softmax
    # probabilities (label name is 'softmax_label') -- confirm for your net.
    mod.forward(Batch([mx.nd.array(img)]))
    prob = np.squeeze(mod.get_outputs()[0].asnumpy())
    top1 = int(np.argmax(prob))  # index of the most probable class
    print(top1)          # output class index
    print(labels[top1])  # output human-readable label
    return top1
# Batch test: classify every regular file in the directory below.
import os

path = '/mxnet/tools/train-cat/2'
# Sorted so the printed results come out in a deterministic order.
for name in sorted(os.listdir(path)):
    image = os.path.join(path, name)
    # os.listdir also returns sub-directories, which cannot be read as
    # images; only pass regular files to predict().
    if not os.path.isfile(image):
        continue
    predict(image)
参考文献:
[1]http://mxnet.io/api/python/model.html?highlight=predict#mxnet.model.FeedForward.predict
[2]http://mxnet.io/tutorials/python/predict_image.html?highlight=predict