- MNIST手写数字数据集：包含训练集和测试集，其中训练集有60000个样本，测试集有10000个样本，每个样本是一张28×28的灰度图像。
- MNIST手写数字训练代码分为四个部分：训练参数配置、数据读取、网络结构搭建和模型训练。
import mxnet as mx
import argparse
import numpy as np
import gzip
import struct
import logging
# Training argument configuration
def get_args(argv=None):
    """Parse command-line options for training LeNet on MNIST.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``num_classes``, ``gpus``, ``batch_size``,
        ``num_epoch``, ``lr``, ``save_result`` and ``save_name``.
    """
    parser = argparse.ArgumentParser(description='train LeNet on the MNIST dataset')
    parser.add_argument('--num-classes', type=int, default=10,
                        help='number of output classes')
    parser.add_argument('--gpus', type=str, default='0',
                        help='comma-separated GPU ids, e.g. "0" or "0,1"; '
                             'pass an empty string to train on CPU')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='mini-batch size')
    parser.add_argument('--num-epoch', type=int, default=10,
                        help='number of training epochs')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--save-result', type=str, default='output/',
                        help='directory that checkpoints are saved under')
    parser.add_argument('--save-name', type=str, default='LeNet',
                        help='file-name prefix for saved checkpoints')
    return parser.parse_args(argv)
# Network construction
def get_network(num_classes):
    """Build a LeNet-style classifier symbol ending in a softmax output.

    Two stages of 5x5 convolution -> ReLU -> 2x2/stride-2 max-pooling
    (6 and then 16 filters) feed three fully-connected layers of 120, 84 and
    ``num_classes`` units, with ReLU between the fully-connected layers.

    Args:
        num_classes: number of units in the final fully-connected layer.

    Returns:
        The ``SoftmaxOutput`` symbol named ``softmax``.
    """
    net = mx.sym.Variable("data")

    # Feature extractor: two conv/relu/pool stages.
    net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=6, name='conv1')
    net = mx.sym.Activation(data=net, act_type='relu', name='relu1')
    net = mx.sym.Pooling(data=net, kernel=(2, 2), stride=(2, 2), pool_type='max', name='pool1')
    net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=16, name='conv2')
    net = mx.sym.Activation(data=net, act_type='relu', name='relu2')
    net = mx.sym.Pooling(data=net, kernel=(2, 2), stride=(2, 2), pool_type='max', name='pool2')

    # Classifier head: three fully-connected layers.
    net = mx.sym.FullyConnected(data=net, num_hidden=120, name='fc1')
    net = mx.sym.Activation(data=net, act_type='relu', name='relu3')
    net = mx.sym.FullyConnected(data=net, num_hidden=84, name='fc2')
    net = mx.sym.Activation(data=net, act_type='relu', name='relu4')
    net = mx.sym.FullyConnected(data=net, num_hidden=num_classes, name='fc3')

    return mx.sym.SoftmaxOutput(data=net, name='softmax')
# Training entry point
if __name__ == '__main__':
    import os  # local import: only the script entry point needs path handling

    args = get_args()

    # --gpus is a comma-separated list of device ids; an empty value selects the CPU.
    if args.gpus:
        context = [mx.gpu(int(index)) for index in args.gpus.strip().split(',')]
    else:
        context = mx.cpu()

    # ---- Data loading ----
    # The MNIST idx files are read from the current working directory.
    train_data = mx.io.MNISTIter(image='train-images.idx3-ubyte', label='train-labels.idx1-ubyte',
                                 batch_size=args.batch_size, shuffle=1)
    val_data = mx.io.MNISTIter(image='t10k-images.idx3-ubyte', label='t10k-labels.idx1-ubyte',
                               batch_size=args.batch_size, shuffle=0)

    # ---- Model ----
    sym = get_network(num_classes=args.num_classes)
    optimizer_params = {'learning_rate': args.lr}
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
    mod = mx.mod.Module(symbol=sym, context=context)

    # ---- Logging and checkpoints ----
    # Create the output directory up front: the log FileHandler and the
    # checkpoint callback both write into it and fail if it does not exist.
    if args.save_result:
        os.makedirs(args.save_result, exist_ok=True)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(os.path.join(args.save_result, 'train.log')))
    logger.info(args)
    # os.path.join keeps the prefix correct whether or not --save-result ends in '/'.
    checkpoint = mx.callback.do_checkpoint(prefix=os.path.join(args.save_result, args.save_name))
    # One epoch is roughly 60000 / batch_size batches (~938 at the default 64),
    # so a frequency of 1000 would never fire; report throughput every 100 batches.
    batch_callback = mx.callback.Speedometer(args.batch_size, 100)

    mod.fit(train_data=train_data, eval_data=val_data, eval_metric='acc',
            optimizer='sgd', optimizer_params=optimizer_params,
            initializer=initializer, num_epoch=args.num_epoch,
            batch_end_callback=batch_callback, epoch_end_callback=checkpoint)
- MNIST手写数字测试代码分为三个部分：模型导入、数据读取和模型推理输出。
import mxnet as mx
import numpy as np
# Model loading
def load_model(model_prefix, index, context, data_shapes, label_shapes):
    """Load a saved checkpoint and bind it as an inference-only Module.

    Args:
        model_prefix: checkpoint prefix passed to ``mx.model.load_checkpoint``.
        index: epoch number of the checkpoint to load.
        context: device (or list of devices) to bind the module on.
        data_shapes: list of ``(name, shape)`` pairs for the data inputs.
        label_shapes: list of ``(name, shape)`` pairs for the label inputs.

    Returns:
        The bound ``mx.mod.Module`` with the checkpoint's parameters set.

    Raises:
        An MXNet error if the checkpoint lacks a parameter the symbol needs,
        instead of silently running with those weights unset.
    """
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, index)
    model = mx.mod.Module(symbol=sym, context=context)
    model.bind(data_shapes=data_shapes, label_shapes=label_shapes, for_training=False)
    # allow_missing=False: every weight must come from the checkpoint. With
    # True, a missing entry would go unreported and inference would use a
    # never-initialized parameter array.
    model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=False)
    return model
# Data loading
def load_data(data_path, scale=1.0 / 256):
    """Read one image file and wrap it as a single-sample inference batch.

    The image is read as single-channel grayscale, cast to float32,
    force-resized to 28x28, multiplied by ``scale``, reordered from HWC to
    CHW, and given a leading batch axis, producing a (1, 1, 28, 28) array.

    Args:
        data_path: path of the image file to read.
        scale: multiplier applied to the raw 0-255 pixel values.
            NOTE(review): the training script reads data with
            ``mx.io.MNISTIter``, which rescales pixels by 1/256; the default
            mirrors that so inference sees the same input range. Pass
            ``scale=1`` to reproduce the old unscaled input. Confirm against
            the MXNet version in use.

    Returns:
        ``mx.io.DataBatch`` holding the single image NDArray.
    """
    data = mx.image.imread(data_path, flag=0)  # flag=0 reads the image as grayscale
    augmenters = [
        mx.image.CastAug(),  # cast pixel values to float32
        mx.image.ForceResizeAug(size=[28, 28]),  # resize to 28x28 regardless of aspect ratio
    ]
    for aug in augmenters:
        data = aug(data)
    data = data * scale  # match the training input range
    data = mx.nd.transpose(data, axes=(2, 0, 1))  # HWC -> CHW (a layout change, not a colour swap)
    data = mx.nd.expand_dims(data, axis=0)  # add the batch axis
    return mx.io.DataBatch([data])
# Model inference
def get_output(model, data):
    """Run one forward pass and return the predicted class index.

    Args:
        model: a bound module whose parameters are already set.
        data: the batch passed straight to ``model.forward``.

    Returns:
        ``np.argmax`` of the scores for the first sample of the first output.
    """
    model.forward(data)
    # get_outputs() returns a list with one array per network output; [0]
    # takes the first output (presumably shaped (batch, num_classes)) and the
    # next [0] takes the first sample in the batch.
    first_sample_scores = model.get_outputs()[0][0]
    scores = first_sample_scores.asnumpy()
    return np.argmax(scores)
# Inference entry point
if __name__ == '__main__':
    # Presumably these must match a training run: the training script's default
    # --save-result/--save-name give the 'output/LeNet' prefix, and 10 is its
    # default --num-epoch. Adjust both if training used other values.
    model_prefix = 'output/LeNet'
    index = 10
    # Prefer the first GPU, but fall back to CPU on machines without one.
    context = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
    # One single-channel 28x28 image per batch.
    data_shapes = [('data', (1, 1, 28, 28))]
    label_shapes = [('softmax_label', (1,))]
    model = load_model(model_prefix, index, context, data_shapes, label_shapes)
    data_path = 'test_image/test1.png'
    data = load_data(data_path)
    cla_label = get_output(model, data)
    print("predict result:{}".format(cla_label))