# Source (来自): http://blog.csdn.net/xinfeng2005/article/details/53380700?locationNum=8&fps=1
# coding=utf-8
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
import struct
import pickle
def ImageToFloat(img, rows=28, cols=28):
    """Convert a batch of images to a float32 (N, 1, rows, cols) array.

    The leading axis of ``img`` is kept as the batch size; the remaining
    values are reshaped into one single-channel ``rows`` x ``cols`` plane
    per sample, so both (N, rows, cols) and flat (N, rows*cols) input work.
    Values are divided by 255, so uint8 pixel data lands in [0, 1].

    Args:
        img: array whose first axis indexes samples.
        rows: image height (defaults to the MNIST size, 28).
        cols: image width (defaults to the MNIST size, 28).

    Returns:
        A new float32 array of shape (img.shape[0], 1, rows, cols).
    """
    # The extra axis of size 1 is the channel dimension the network input expects.
    return img.reshape(img.shape[0], 1, rows, cols).astype(np.float32) / 255.0
# Choose which saved model to evaluate:
#   'lenet'     -> the LeNet network
#   'mpl_mnist' -> the MLP network (set MODEL_PREFIX to this to evaluate it)
# MODEL_EPOCH is the checkpoint epoch number handed to FeedForward.load.
MODEL_PREFIX = 'lenet'
MODEL_EPOCH = 10
model = mx.model.FeedForward.load(MODEL_PREFIX, MODEL_EPOCH)
# Test set: load the MNIST test labels and images, used below for per-image
# recognition and for overall test accuracy.
# Both IDX files are big-endian binary. Open them in 'rb' mode so that no
# text-mode decoding or newline translation touches the bytes.
with open('./mnist/t10k-labels-idx1-ubyte', 'rb') as flbl:
    magic, num = struct.unpack(">II", flbl.read(8))
    # 2049 identifies an IDX label file; reject anything else early.
    if magic != 2049:
        raise ValueError('unexpected MNIST label file magic number: %d' % magic)
    # frombuffer replaces the deprecated binary mode of np.fromstring.
    label = np.frombuffer(flbl.read(), dtype=np.int8)
with open('./mnist/t10k-images-idx3-ubyte', 'rb') as fimg:
    magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
    # 2051 identifies an IDX image file.
    if magic != 2051:
        raise ValueError('unexpected MNIST image file magic number: %d' % magic)
    image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
# Show the first 25 test digits in a 5x5 grid. Each panel is titled with the
# true label and the model's highest class probability.
n_show = 25
# A single batched predict() call replaces 25 separate single-image calls.
probs = model.predict(ImageToFloat(image[:n_show]))
for x in range(n_show):
    prob = probs[x]
    plt.subplot(5, 5, x + 1)
    # Invert the pixel intensities before drawing with the reversed grey colormap.
    plt.imshow(255 - image[x], cmap='Greys_r')
    print('Classified as %d with probability %f' % (prob.argmax(), prob.max()))
    plt.title('%s %s' % (str(label[x]), str(prob.max())))
    plt.axis('off')
plt.show()
# Accuracy of the model over the whole test set, fed in batches of 100.
val_iter = mx.io.NDArrayIter(ImageToFloat(image), label, batch_size=100)
print('Test accuracy: %f%%' % (model.score(val_iter) * 100,))
# Training set: load the MNIST training labels and images, used below for
# training-set recognition accuracy.
# Both IDX files are big-endian binary. Open them in 'rb' mode so that no
# text-mode decoding or newline translation touches the bytes.
with open('./mnist/train-labels-idx1-ubyte', 'rb') as flbl:
    magic, num = struct.unpack(">II", flbl.read(8))
    # 2049 identifies an IDX label file; reject anything else early.
    if magic != 2049:
        raise ValueError('unexpected MNIST label file magic number: %d' % magic)
    # frombuffer replaces the deprecated binary mode of np.fromstring.
    label_train = np.frombuffer(flbl.read(), dtype=np.int8)
with open('./mnist/train-images-idx3-ubyte', 'rb') as fimg:
    magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
    # 2051 identifies an IDX image file.
    if magic != 2051:
        raise ValueError('unexpected MNIST image file magic number: %d' % magic)
    image_train = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label_train), rows, cols)
# Display the first training image and print the model's prediction for it.
plt.axis('off')
plt.imshow(255 - image_train[0], cmap='Greys_r')
# Slicing with 0:1 keeps the batch axis, so ImageToFloat receives shape (1, rows, cols).
prob = model.predict(ImageToFloat(image_train[0:1]))[0]
print('Classified as %d with probability %f' % (prob.argmax(), prob.max()))
plt.show()
# Accuracy of the model over the whole training set, fed in batches of 100.
train_iter = mx.io.NDArrayIter(ImageToFloat(image_train), label_train, batch_size=100)
print('Train accuracy: %f%%' % (model.score(train_iter) * 100,))
# Results after 10 training epochs (迭代10次后):
#   MLP:   Test accuracy: 97.390000%   Train accuracy: 98.821667%
#   LeNet: Test accuracy: 99.170000%   Train accuracy: 99.995000%