[Repost] A Guide to Training and Prediction on the MNIST Handwritten Digit Dataset with MXNet

Environment:

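Anaconda3 (64-bit) with MXNet 1.3.1 installed. opencv_python-3.4.5.20-cp36-cp36m-win_amd64.whl is optional: the scripts below read and display images with PIL and matplotlib, both of which ship with Anaconda.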

Training source code:

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 16:30:15 2019

@author: houwenbin
"""

import numpy as np
import mxnet as mx
import logging

logging.getLogger().setLevel(logging.DEBUG)

batch_size = 100
mnist = mx.test_utils.get_mnist()
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

data = mx.sym.var('data')
# first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))
# second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))
# first fully connected layer
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fully connected layer
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

# create a trainable module on the CPU (pass context=mx.gpu(0) to train on GPU 0 instead)
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())

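# train for 10 epochs with SGD, evaluating accuracy on the test set after every epoch;
# Speedometer logs throughput and training accuracy every 100 batches
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate': 0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                num_epoch=10)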

# save the network and its parameters as lenet-symbol.json and lenet-0010.params;
# the third argument (save_optimizer_states=False) skips saving the optimizer state
lenet_model.save_checkpoint("lenet", 10, False)
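To see why `Flatten` hands 800 values to `fc1`, the feature-map sizes can be traced by hand. The sketch below is plain Python and assumes MXNet's defaults for the layers above (stride 1 and no padding for `Convolution`, and the "valid" pooling convention, which rounds down); the helper names `conv_out`, `lenet_shapes` and `lenet_param_count` are made up for illustration.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output height/width of a convolution or pooling window, rounding down."""
    return (size + 2 * pad - kernel) // stride + 1


def lenet_shapes(size=28):
    """Trace (channels, height, width) through the LeNet defined above."""
    shapes = []
    s = conv_out(size, 5)              # conv1: 5x5 kernel, 20 filters
    shapes.append(("conv1", (20, s, s)))
    s = conv_out(s, 2, stride=2)       # pool1: 2x2 max pooling, stride 2
    shapes.append(("pool1", (20, s, s)))
    s = conv_out(s, 5)                 # conv2: 5x5 kernel, 50 filters
    shapes.append(("conv2", (50, s, s)))
    s = conv_out(s, 2, stride=2)       # pool2: 2x2 max pooling, stride 2
    shapes.append(("pool2", (50, s, s)))
    shapes.append(("flatten", (50 * s * s,)))
    return shapes


def lenet_param_count(size=28):
    """Weights plus biases for each layer (1 input channel, 10 classes)."""
    flat = dict(lenet_shapes(size))["flatten"][0]
    return {
        "conv1": 20 * 1 * 5 * 5 + 20,
        "conv2": 50 * 20 * 5 * 5 + 50,
        "fc1": flat * 500 + 500,
        "fc2": 500 * 10 + 10,
    }


for name, shape in lenet_shapes():
    print(name, shape)
print("total parameters:", sum(lenet_param_count().values()))
```

With 28x28 inputs, `pool2` is 50x4x4, so `fc1` receives 800 features and alone holds 400,500 of the 431,080 parameters. A different input size would change the flatten width and therefore the shape of the saved `fc1` weights, which is why the prediction script resizes every image to 28x28.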


Prediction source code:

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 20:17:26 2019

@author: houwenbin
"""

import time
import mxnet as mx
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

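# checkpoint prefix and epoch written by the training script
prefix = 'lenet'
iteration = 10
img_name = './digit_8.jpg'
synsets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# image preprocessing (MNIST style: 28x28 grayscale, scaled to [0, 1])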
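def load_image(img_name):
    # Read the image with PIL.
    # Related: scipy.misc.imread and scipy.ndimage.imread. misc.imread accepts an optional
    # mode argument but essentially calls PIL; see its source or documentation for the modes.
    # Image.open raises an exception on failure rather than returning None, so catch it.
    try:
        img = Image.open(img_name)
    except OSError:
        return None

    img = img.resize((28, 28))
    # convert to grayscale (L = 0.299R + 0.587G + 0.114B), then to a float ('f') array
    img = np.array(img.convert('L'), 'f')
    # Display everything with plt. Both plt and cv2.imshow take numpy arrays in Python,
    # but cv2.imread returns BGR, so swap the channels before showing such an image with plt.
    print(img.shape)
    plt.imshow(img, cmap='gray')
    plt.show()
    # add batch and channel axes (NCHW) and scale to [0, 1], matching get_mnist()
    img = img.reshape(1, 1, 28, 28).astype(np.float32) / 255
    return img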


time0 = time.time()

# load the MXNet symbol and parameters
sym, arg, aux = mx.model.load_checkpoint(prefix, iteration)
# rebuild the module for inference only (no label input)
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1, 1, 28, 28))], label_shapes=mod._label_shapes)  # allocate memory for the input
mod.set_params(arg, aux, allow_missing=True)  # load the model parameters

time1 = time.time()
print("Model loading and rebuild time: {0}".format(time1 - time0))

# load the image
img = load_image(img_name)
if img is None:
    raise SystemExit("Cannot open image: " + img_name)

print(img.shape)
time0 = time.time()

# define a simple data batch
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

# compute the predicted probabilities; img is a 1x1x28x28 array, so this is a single-image inference
mod.forward(Batch([mx.nd.array(img)]))
# MXNet runs operations asynchronously; wait for them so the timing covers the actual computation
mx.nd.waitall()
time1 = time.time()
print("Forward (inference) time: {0}".format(time1 - time0))

# print the top-5 predictions
print(mod.get_outputs())
prob = mod.get_outputs()[0].asnumpy()  # fetch the result as a numpy array
print("-------result-------", prob, prob.shape)
prob = np.squeeze(prob)
print("-------squeeze result-------", prob, prob.shape)
print("-------sorted prob--------", np.sort(prob))  # ascending order
print("-------arg sorted prob--------", np.argsort(prob))
a = np.argsort(prob)[::-1]  # class indices ordered by confidence, highest first
print("------top sorted index-------", a, a.shape)
for i in a[0:5]:
    print('probability=%f, class=%s' % (prob[i], synsets[i]))
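The argsort-then-reverse idiom above can be checked offline against the probabilities printed in the sample run later in this post. This numpy-only sketch copies those ten values by hand; `top_k` and `top_k_partition` are hypothetical helpers, and `np.argpartition` is shown only as an alternative that avoids a full sort when there are many classes.

```python
import numpy as np

# Softmax output copied from the sample run (array index = digit class).
prob = np.array([3.0556594e-06, 1.3175709e-06, 4.1811345e-06, 1.1044953e-08,
                 9.9990916e-01, 4.0004899e-10, 3.0342795e-05, 1.8727254e-05,
                 2.3288235e-06, 3.0720061e-05], dtype=np.float32)


def top_k(prob, k=5):
    """Indices of the k largest probabilities, largest first (full sort)."""
    return np.argsort(prob)[::-1][:k]


def top_k_partition(prob, k=5):
    """Same result: pick the k largest with argpartition, then sort only those k."""
    idx = np.argpartition(prob, -k)[-k:]
    return idx[np.argsort(prob[idx])[::-1]]


print(top_k(prob))             # [4 9 6 7 2]
print(top_k_partition(prob))   # [4 9 6 7 2]
print(float(prob.sum()))       # softmax outputs sum to approximately 1
```

For ten classes the full sort is perfectly adequate; `argpartition` only pays off for much larger label sets.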


Data preparation:

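In a paint program, create a 128x128 canvas with a black background, then use the eraser to rub out the shape of the digit to be recognized so that it appears white on black, like the MNIST digits (this article uses digit_8.jpg).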
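MNIST digits are light strokes on a dark background, which is why the canvas is black and the digit is erased out in white. An image drawn the other way round (black on white) would look unlike anything the network saw in training. The sketch below is a hypothetical helper, not part of the original scripts: it takes an already resized 28x28 grayscale array, inverts it when the border looks bright (a simple heuristic assumed here), scales it to [0, 1] as `load_image` does, and reshapes it to the (1, 1, 28, 28) NCHW shape the bound module expects.

```python
import numpy as np


def to_mnist_input(gray, border=2, threshold=127.0):
    """Turn a 28x28 grayscale image (values 0-255) into a 1x1x28x28 float32 batch.

    Heuristic (an assumption, not from the original post): if the mean of the
    outer `border` pixels is brighter than `threshold`, treat the image as
    dark-on-light and invert it to match MNIST's light-on-dark style.
    """
    img = np.asarray(gray, dtype=np.float32)
    if img.shape != (28, 28):
        raise ValueError("expected a 28x28 grayscale array, got %s" % (img.shape,))
    edge = np.concatenate([img[:border].ravel(), img[-border:].ravel(),
                           img[:, :border].ravel(), img[:, -border:].ravel()])
    if edge.mean() > threshold:
        img = 255.0 - img
    return (img / 255.0).reshape(1, 1, 28, 28)


# A white vertical stroke on black (MNIST style) and the same stroke inverted.
light_on_dark = np.zeros((28, 28), dtype=np.uint8)
light_on_dark[4:24, 13:15] = 255
dark_on_light = 255 - light_on_dark

a = to_mnist_input(light_on_dark)
b = to_mnist_input(dark_on_light)
print(a.shape, np.allclose(a, b))  # (1, 1, 28, 28) True
```

The border test is deliberately crude; it works for a centered digit on a plain background but would misfire if the stroke touches the edges of the image.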

Output:

In [1]: runfile('C:/Users/houwenbin/Documents/PythonProject/test_mnist.py', wdir='C:/Users/houwenbin/Documents/PythonProject')
Model loading and rebuild time: 0.0060160160064697266
(28, 28)
(1, 1, 28, 28)
Forward (inference) time: 0.0010042190551757812
[
[[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
  4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]]
<NDArray 1x10 @cpu(0)>]
-------result------- [[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
  4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]] (1, 10)
-------squeeze result------- [3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
 4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05] (10,)
-------sorted prob-------- [4.0004899e-10 1.1044953e-08 1.3175709e-06 2.3288235e-06 3.0556594e-06
 4.1811345e-06 1.8727254e-05 3.0342795e-05 3.0720061e-05 9.9990916e-01]
-------arg sorted prob-------- [5 3 1 8 0 2 7 6 9 4]
------top sorted index------- [4 9 6 7 2 0 8 1 3 5] (10,)
probability=0.999909, class=4
probability=0.000031, class=9
probability=0.000030, class=6
probability=0.000019, class=7
probability=0.000004, class=2
