基于MXNet实现MNIST手写数字体识别

  • MNIST手写数字数据集：包含训练集和测试集，其中训练集有60000个样本，测试集有10000个样本。
  • MNIST手写数字训练代码分为：训练参数配置、数据读取、网络结构搭建、模型训练四个部分。
import argparse
import gzip
import logging
import os
import struct

import mxnet as mx
import numpy as np

# 训练参数配置
def get_args(argv=None):
    """Parse the command-line options of the LeNet/MNIST training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list lets
            tests or other scripts build the options programmatically.

    Returns:
        argparse.Namespace with ``num_classes``, ``gpus``, ``batch_size``,
        ``num_epoch``, ``lr``, ``save_result`` and ``save_name``.
    """
    parser = argparse.ArgumentParser(description='train LeNet on the MNIST dataset')
    parser.add_argument('--num-classes', type=int, default=10,
                        help='number of output classes')
    parser.add_argument('--gpus', type=str, default='0',
                        help="comma-separated GPU ids, e.g. '0,1'; an empty string trains on CPU")
    parser.add_argument('--batch-size', type=int, default=64, help='mini-batch size')
    parser.add_argument('--num-epoch', type=int, default=10, help='number of training epochs')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--save-result', type=str, default='output/',
                        help='directory for the training log and checkpoints')
    parser.add_argument('--save-name', type=str, default='LeNet',
                        help='checkpoint file-name prefix')
    return parser.parse_args(argv)

# 神经网络搭建模块
def get_network(num_classes):
    """Build the LeNet-style symbol used for MNIST classification.

    Layout: two stages of (5x5 convolution -> ReLU -> 2x2 stride-2 max-pool)
    with 6 and 16 filters, then fully connected layers of 120 and 84 units
    each followed by ReLU, a final fully connected layer with
    ``num_classes`` units, and a SoftmaxOutput head named 'softmax'.

    Args:
        num_classes: number of units in the last fully connected layer.

    Returns:
        The output symbol of the network.
    """
    net = mx.sym.Variable("data")

    # Feature extractor: conv -> relu -> pool, repeated for each stage.
    conv_stages = (
        ('conv1', 'relu1', 'pool1', 6),
        ('conv2', 'relu2', 'pool2', 16),
    )
    for conv_name, act_name, pool_name, filters in conv_stages:
        net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=filters, name=conv_name)
        net = mx.sym.Activation(data=net, act_type='relu', name=act_name)
        net = mx.sym.Pooling(data=net, kernel=(2, 2), stride=(2, 2), pool_type='max', name=pool_name)

    # Classifier: hidden fully connected layers, each followed by ReLU.
    hidden_stages = (
        ('fc1', 'relu3', 120),
        ('fc2', 'relu4', 84),
    )
    for fc_name, act_name, units in hidden_stages:
        net = mx.sym.FullyConnected(data=net, num_hidden=units, name=fc_name)
        net = mx.sym.Activation(data=net, act_type='relu', name=act_name)

    # Class scores followed by the softmax training/inference head.
    net = mx.sym.FullyConnected(data=net, num_hidden=num_classes, name='fc3')
    return mx.sym.SoftmaxOutput(data=net, name='softmax')

# 训练代码入口
if __name__ == '__main__':
    args = get_args()

    # --gpus is a comma-separated list of device ids; an empty (or blank)
    # value selects the CPU. Stripping first avoids int('') on "--gpus ' '".
    gpu_ids = args.gpus.strip()
    if gpu_ids:
        context = [mx.gpu(int(index)) for index in gpu_ids.split(',')]
    else:
        context = mx.cpu()

    # The log file and checkpoints are written under --save-result; create the
    # directory up front, otherwise FileHandler / do_checkpoint fail when it
    # does not exist yet.
    os.makedirs(args.save_result, exist_ok=True)

    # ---- data loading ----
    # The raw MNIST idx files are read from the current working directory.
    train_data = mx.io.MNISTIter(image='train-images.idx3-ubyte', label='train-labels.idx1-ubyte', batch_size=args.batch_size, shuffle=1)
    val_data = mx.io.MNISTIter(image='t10k-images.idx3-ubyte', label='t10k-labels.idx1-ubyte', batch_size=args.batch_size, shuffle=0)

    sym = get_network(num_classes=args.num_classes)
    optimizer_params = {'learning_rate': args.lr}
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
    mod = mx.mod.Module(symbol=sym, context=context)

    # Log to both the console and <save_result>/train.log.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(os.path.join(args.save_result, 'train.log')))
    logger.info(args)

    # os.path.join works whether or not --save-result ends with a separator.
    checkpoint = mx.callback.do_checkpoint(prefix=os.path.join(args.save_result, args.save_name))
    # 60000 samples / batch size 64 is ~938 batches per epoch, so a frequency
    # of 1000 never logged; report speed/accuracy every 100 batches instead.
    batch_callback = mx.callback.Speedometer(args.batch_size, 100)
    mod.fit(train_data=train_data, eval_data=val_data, eval_metric='acc', optimizer_params=optimizer_params, optimizer='sgd', batch_end_callback=batch_callback, initializer=initializer, num_epoch=args.num_epoch, epoch_end_callback=checkpoint)

  • MNIST手写数字测试代码分为：模型导入、数据读取、模型推理输出三个部分。
import mxnet as mx
import numpy as np
# 模型导入
def load_model(model_prefix, index, context, data_shapes, label_shapes):
    """Load a saved checkpoint into a Module bound for inference.

    Args:
        model_prefix: checkpoint prefix passed to ``mx.model.load_checkpoint``
            (e.g. 'output/LeNet').
        index: epoch number of the checkpoint to load.
        context: device (or list of devices) to run the module on.
        data_shapes: list of (name, shape) pairs describing the data inputs.
        label_shapes: list of (name, shape) pairs describing the label inputs.

    Returns:
        An ``mx.mod.Module`` bound with ``for_training=False`` and holding the
        checkpoint's parameters.

    Raises:
        An MXNet error if the checkpoint lacks a parameter the symbol needs.
    """
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, index)
    model = mx.mod.Module(symbol=sym, context=context)
    model.bind(data_shapes=data_shapes, label_shapes=label_shapes, for_training=False)
    # allow_missing=False: no initializer is supplied here, so with
    # allow_missing=True any weight absent from the checkpoint would silently
    # be left unloaded and predictions would be meaningless. Fail loudly instead.
    model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=False)
    return model
# 数据读取
def load_data(data_path):
    """Read one image file and wrap it as a single-sample DataBatch.

    The image is read as grayscale and cast to float32. It is then resized to
    28x28, scaled the same way as the training data, reordered from HWC to
    CHW, and given a leading batch axis.

    Args:
        data_path: path of the image file to classify.

    Returns:
        ``mx.io.DataBatch`` whose single data array has shape (1, C, 28, 28).
    """
    data = mx.image.imread(data_path, flag=0)  # flag=0 -> decode as grayscale
    cla_cast_aug = mx.image.CastAug()  # cast pixels to float32
    cla_resize_aug = mx.image.ForceResizeAug(size=[28, 28])  # resize to 28x28, ignoring aspect ratio
    cla_augmenters = [cla_cast_aug, cla_resize_aug]

    for aug in cla_augmenters:
        data = aug(data)
    # The training script feeds data through mx.io.MNISTIter, which scales raw
    # pixels by 1/256 into [0, 1). Apply the same scaling here; otherwise the
    # network would see 0-255 inputs it was never trained on.
    data = data / 256.0
    data = mx.nd.transpose(data, axes=(2, 0, 1))  # HWC -> CHW layout (no colour conversion)
    data = mx.nd.expand_dims(data, axis=0)  # add the batch axis
    data = mx.io.DataBatch([data])
    return data
# 模型推理输出
def get_output(model, data):
    """Run one forward pass and return the predicted class index.

    Args:
        model: a bound Module, such as the one returned by load_model.
        data: a DataBatch holding the sample(s) to classify.

    Returns:
        Index of the highest-scoring class for the first sample in the batch.
    """
    model.forward(data)
    # get_outputs() yields one array per network output. The first index picks
    # the (only) softmax output, whose rows are samples; the second index picks
    # the first sample's class scores.
    softmax_output = model.get_outputs()[0]
    scores = softmax_output[0].asnumpy()
    return np.argmax(scores)
# 测试代码入口
if __name__ == '__main__':
    # Checkpoint location and epoch; 10 matches the training default --num-epoch.
    prefix, epoch = 'output/LeNet', 10
    device = mx.gpu(0)

    # One grayscale 28x28 image per batch, with its matching label slot.
    input_shapes = [('data', (1, 1, 28, 28))]
    target_shapes = [('softmax_label', (1,))]
    model = load_model(prefix, epoch, device, input_shapes, target_shapes)

    batch = load_data('test_image/test1.png')
    predicted = get_output(model, batch)
    print("predict result:{}".format(predicted))
   
    

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值