Training a Multilayer Perceptron on MNIST with MXNet, Explained [Repost]

Source: http://www.cnblogs.com/Mu001999/p/6221093.html

# Import the required modules
import numpy as np  # numpy stores raw numeric values for fast numerical computation, avoiding Python lists, which only hold pointers to objects
import os           # not actually used in this example
import gzip         # zlib-based compression; used here to read the gzip-compressed data files
import struct       # used to parse the binary headers and pixel data in the MNIST files
import mxnet as mx  # the MXNet package
import logging      # the logging package, used to print training progress

# Train on the MNIST dataset

def read_data(label_url, image_url):  # define a function that reads one label file and one image file
    with gzip.open(label_url) as flbl:  # decompress the label file
        magic, num = struct.unpack(">II", flbl.read(8))  # read two ints in big-endian order; per the official MNIST format, magic is the magic number (MSB first) identifying the file type and num is the number of items in the file
        label = np.fromstring(flbl.read(), dtype=np.int8)  # convert the remaining bytes of the label file into an int8 (-128 to 127) array, one entry per label
    with gzip.open(image_url, 'rb') as fimg:  # decompress the image file in binary read mode
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))  # read four big-endian ints; magic and num are as above, rows and cols are the image height and width
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)  # read the remaining bytes as uint8 and reshape them into (number of labels, rows, cols)
    return (label, image)  # return the parsed label and image arrays
# note that fileobject.read(size) reads the file as a stream (you can test this yourself)
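
To make the big-endian header parsing above concrete, here is a minimal, self-contained sketch (the header variable is an illustrative byte string built by hand; per the MNIST format, the label file's magic number is 2049 and the training set contains 60000 labels):

# parsing an 8-byte MNIST label-file header with struct.unpack(">II", ...)
import struct

header = b'\x00\x00\x08\x01\x00\x00\xea\x60'  # example bytes: magic=2049 (0x0801), num=60000 (0xEA60)
magic, num = struct.unpack(">II", header)     # ">" = big-endian, "II" = two unsigned 32-bit ints
print(magic, num)                             # prints: 2049 60000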

(train_lbl, train_img) = read_data('mnist/train-labels-idx1-ubyte.gz', 'mnist/train-images-idx3-ubyte.gz')  # load the training data
(val_lbl, val_img) = read_data('mnist/t10k-labels-idx1-ubyte.gz', 'mnist/t10k-images-idx3-ubyte.gz')  # load the validation (test) data

def to4d(img):  # define a function that turns the image array into a 4-D tensor
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255  # reshape to (number of samples, channels, rows, cols); MNIST is grayscale, so the channel count is 1; convert to float32 and divide by 255 so pixel values lie in [0, 1]
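
As a quick sanity check (a minimal sketch; train_4d is just a throwaway name, assuming the data was loaded with read_data above), the reshaped training tensor should have shape (60000, 1, 28, 28), dtype float32, and values in [0, 1]:

train_4d = to4d(train_img)
print(train_img.shape)                  # (60000, 28, 28): raw images
print(train_4d.shape)                   # (60000, 1, 28, 28): channel dimension added
print(train_4d.dtype)                   # float32
print(train_4d.min(), train_4d.max())   # 0.0 1.0, since pixel values were divided by 255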

batch_size = 100  # process 100 samples per batch
train_iter = mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)  # build the training data iterator; shuffle=True randomizes the order of the training samples so batches are not fed in file order
val_iter = mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)  # build the validation data iterator (a batch is shown in the sketch below)
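
To see what the iterator actually feeds the network, one can pull a single batch and inspect its shapes (a minimal sketch, assuming NDArrayIter yields DataBatch objects whose data and label fields are lists of NDArrays; batch is a throwaway name):

batch = next(iter(train_iter))   # fetch the first DataBatch
print(batch.data[0].shape)       # (100, 1, 28, 28): batch_size images with 1 channel each
print(batch.label[0].shape)      # (100,): one label per image
train_iter.reset()               # rewind the iterator before training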

# Build the multilayer network
data = mx.sym.Variable('data')  # create a placeholder variable for the input data
data = mx.sym.Flatten(data=data)  # flatten the 4-D input into 2-D: the first dimension is the batch size, the second is channels x rows x cols (i.e. the pixel count times the channel count)
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)  # first fully connected layer, taking data as input; num_hidden gives it 128 output nodes
act1 = mx.sym.Activation(data=fc1, name='relu1', act_type='relu')  # ReLU activation for the first fully connected layer, taking fc1 as input
fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden=64)  # second fully connected layer, taking act1 as input; num_hidden gives it 64 output nodes
act2 = mx.sym.Activation(data=fc2, name='relu2', act_type='relu')  # ReLU activation for the second fully connected layer, taking fc2 as input
fc3 = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)  # third (output) fully connected layer, taking act2 as input, with 10 output nodes, one per digit class
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')  # apply a softmax to the output and use the cross-entropy (log) loss for backpropagation
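
Once the symbol graph is defined, its parameters and output shape can be inspected before training (a minimal sketch using the Symbol API's list_arguments and infer_shape; the printed values assume the (batch_size, 1, 28, 28) input built above):

print(mlp.list_arguments())
# expected: ['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'fc3_weight', 'fc3_bias', 'softmax_label']

arg_shapes, out_shapes, aux_shapes = mlp.infer_shape(data=(100, 1, 28, 28))
print(out_shapes)   # [(100, 10)]: a 10-way class probability distribution per image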

logging.getLogger().setLevel(logging.DEBUG)  # get the root logger and set its level to DEBUG so training progress is printed

# Build the feed-forward model
model = mx.model.FeedForward(
    symbol = mlp,         # use the mlp symbol defined above as the network structure
    num_epoch = 20,       # train for 20 epochs
    learning_rate = 0.1   # train with a learning rate of 0.1
)
# Fit the model to the data (train)
model.fit(
    X = train_iter,        # the training data iterator
    eval_data = val_iter,  # the validation data iterator
    batch_end_callback = mx.callback.Speedometer(batch_size, 200)  # called at the end of each batch; prints a logging line every 200 batches
)
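
Note that the run below warns that mx.model.FeedForward is deprecated in favor of mx.mod.Module. A minimal sketch of the equivalent Module-based training (same symbol, iterators, and hyperparameters; not run here, so treat it as an assumption about the Module API rather than the author's code) might look like this:

mod = mx.mod.Module(symbol=mlp, context=mx.cpu())   # wrap the same mlp symbol in a Module on CPU
mod.fit(
    train_iter,                                      # training data iterator
    eval_data=val_iter,                              # validation data iterator
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.1},
    num_epoch=20,
    batch_end_callback=mx.callback.Speedometer(batch_size, 200)
)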

Result:

chenqy@cqy:~/test_mxnet$ python mnist.py 
mnist.py:44: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
  learning_rate = 0.1
/home/chenqy/anaconda3/lib/python3.6/site-packages/mxnet/model.py:526: DeprecationWarning: Calling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.
  self.initializer(k, v)
INFO:root:Start training with [cpu(0)]
INFO:root:Epoch[0] Batch [200]  Speed: 28654.44 samples/sec     accuracy=0.110850
INFO:root:Epoch[0] Batch [400]  Speed: 28355.02 samples/sec     accuracy=0.111300
INFO:root:Epoch[0] Batch [600]  Speed: 28450.97 samples/sec     accuracy=0.225650
INFO:root:Epoch[0] Resetting Data Iterator
INFO:root:Epoch[0] Time cost=2.136
INFO:root:Epoch[0] Validation-accuracy=0.462500
INFO:root:Epoch[1] Batch [200]  Speed: 28150.14 samples/sec     accuracy=0.617100
INFO:root:Epoch[1] Batch [400]  Speed: 28452.33 samples/sec     accuracy=0.797150
INFO:root:Epoch[1] Batch [600]  Speed: 27642.93 samples/sec     accuracy=0.840200
INFO:root:Epoch[1] Resetting Data Iterator
INFO:root:Epoch[1] Time cost=2.142
INFO:root:Epoch[1] Validation-accuracy=0.863000
INFO:root:Epoch[2] Batch [200]  Speed: 28472.69 samples/sec     accuracy=0.870900
INFO:root:Epoch[2] Batch [400]  Speed: 28021.98 samples/sec     accuracy=0.895400
INFO:root:Epoch[2] Batch [600]  Speed: 29820.60 samples/sec     accuracy=0.912850
INFO:root:Epoch[2] Resetting Data Iterator
INFO:root:Epoch[2] Time cost=2.091
INFO:root:Epoch[2] Validation-accuracy=0.913300
INFO:root:Epoch[3] Batch [200]  Speed: 28586.32 samples/sec     accuracy=0.921850
INFO:root:Epoch[3] Batch [400]  Speed: 26254.28 samples/sec     accuracy=0.931350
INFO:root:Epoch[3] Batch [600]  Speed: 28702.82 samples/sec     accuracy=0.939500
INFO:root:Epoch[3] Resetting Data Iterator
INFO:root:Epoch[3] Time cost=2.163
INFO:root:Epoch[3] Validation-accuracy=0.939200
INFO:root:Epoch[4] Batch [200]  Speed: 28974.95 samples/sec     accuracy=0.944500
INFO:root:Epoch[4] Batch [400]  Speed: 26773.51 samples/sec     accuracy=0.948150
INFO:root:Epoch[4] Batch [600]  Speed: 30578.17 samples/sec     accuracy=0.953950
INFO:root:Epoch[4] Resetting Data Iterator
INFO:root:Epoch[4] Time cost=2.098
INFO:root:Epoch[4] Validation-accuracy=0.949100
INFO:root:Epoch[5] Batch [200]  Speed: 28481.22 samples/sec     accuracy=0.956750
INFO:root:Epoch[5] Batch [400]  Speed: 26666.94 samples/sec     accuracy=0.958000
INFO:root:Epoch[5] Batch [600]  Speed: 27876.00 samples/sec     accuracy=0.961450
INFO:root:Epoch[5] Resetting Data Iterator
INFO:root:Epoch[5] Time cost=2.179
INFO:root:Epoch[5] Validation-accuracy=0.955500
INFO:root:Epoch[6] Batch [200]  Speed: 30656.54 samples/sec     accuracy=0.963450
INFO:root:Epoch[6] Batch [400]  Speed: 28489.96 samples/sec     accuracy=0.965100
INFO:root:Epoch[6] Batch [600]  Speed: 20908.39 samples/sec     accuracy=0.967750
INFO:root:Epoch[6] Resetting Data Iterator
INFO:root:Epoch[6] Time cost=2.316
INFO:root:Epoch[6] Validation-accuracy=0.962900
INFO:root:Epoch[7] Batch [200]  Speed: 28951.96 samples/sec     accuracy=0.968050
INFO:root:Epoch[7] Batch [400]  Speed: 30919.49 samples/sec     accuracy=0.970350
INFO:root:Epoch[7] Batch [600]  Speed: 28214.94 samples/sec     accuracy=0.971500
INFO:root:Epoch[7] Resetting Data Iterator
INFO:root:Epoch[7] Time cost=2.051
INFO:root:Epoch[7] Validation-accuracy=0.965700
INFO:root:Epoch[8] Batch [200]  Speed: 30160.24 samples/sec     accuracy=0.973700
INFO:root:Epoch[8] Batch [400]  Speed: 28007.56 samples/sec     accuracy=0.974350
INFO:root:Epoch[8] Batch [600]  Speed: 28068.66 samples/sec     accuracy=0.974650
INFO:root:Epoch[8] Resetting Data Iterator
INFO:root:Epoch[8] Time cost=2.094
INFO:root:Epoch[8] Validation-accuracy=0.967800
INFO:root:Epoch[9] Batch [200]  Speed: 28675.15 samples/sec     accuracy=0.976800
INFO:root:Epoch[9] Batch [400]  Speed: 28443.64 samples/sec     accuracy=0.978250
INFO:root:Epoch[9] Batch [600]  Speed: 27833.02 samples/sec     accuracy=0.977750
INFO:root:Epoch[9] Resetting Data Iterator
INFO:root:Epoch[9] Time cost=2.124
INFO:root:Epoch[9] Validation-accuracy=0.969700
INFO:root:Epoch[10] Batch [200] Speed: 27748.34 samples/sec     accuracy=0.979850
INFO:root:Epoch[10] Batch [400] Speed: 28045.86 samples/sec     accuracy=0.981200
INFO:root:Epoch[10] Batch [600] Speed: 27743.39 samples/sec     accuracy=0.981150
INFO:root:Epoch[10] Resetting Data Iterator
INFO:root:Epoch[10] Time cost=2.159
INFO:root:Epoch[10] Validation-accuracy=0.971300
INFO:root:Epoch[11] Batch [200] Speed: 28185.11 samples/sec     accuracy=0.982350
INFO:root:Epoch[11] Batch [400] Speed: 28095.23 samples/sec     accuracy=0.983050
INFO:root:Epoch[11] Batch [600] Speed: 27157.69 samples/sec     accuracy=0.983950
INFO:root:Epoch[11] Resetting Data Iterator
INFO:root:Epoch[11] Time cost=2.162
INFO:root:Epoch[11] Validation-accuracy=0.971400
INFO:root:Epoch[12] Batch [200] Speed: 28228.54 samples/sec     accuracy=0.984450
INFO:root:Epoch[12] Batch [400] Speed: 27552.26 samples/sec     accuracy=0.985000
INFO:root:Epoch[12] Batch [600] Speed: 28388.36 samples/sec     accuracy=0.985550
INFO:root:Epoch[12] Resetting Data Iterator
INFO:root:Epoch[12] Time cost=2.144
INFO:root:Epoch[12] Validation-accuracy=0.972100
INFO:root:Epoch[13] Batch [200] Speed: 28875.11 samples/sec     accuracy=0.987250
INFO:root:Epoch[13] Batch [400] Speed: 26348.66 samples/sec     accuracy=0.986850
INFO:root:Epoch[13] Batch [600] Speed: 28737.18 samples/sec     accuracy=0.987150
INFO:root:Epoch[13] Resetting Data Iterator
INFO:root:Epoch[13] Time cost=2.152
INFO:root:Epoch[13] Validation-accuracy=0.972000
INFO:root:Epoch[14] Batch [200] Speed: 27254.97 samples/sec     accuracy=0.988750
INFO:root:Epoch[14] Batch [400] Speed: 28949.99 samples/sec     accuracy=0.988300
INFO:root:Epoch[14] Batch [600] Speed: 28010.31 samples/sec     accuracy=0.989200
INFO:root:Epoch[14] Resetting Data Iterator
INFO:root:Epoch[14] Time cost=2.143
INFO:root:Epoch[14] Validation-accuracy=0.971800
INFO:root:Epoch[15] Batch [200] Speed: 27872.34 samples/sec     accuracy=0.990150
INFO:root:Epoch[15] Batch [400] Speed: 27604.11 samples/sec     accuracy=0.989900
INFO:root:Epoch[15] Batch [600] Speed: 27109.95 samples/sec     accuracy=0.990900
INFO:root:Epoch[15] Resetting Data Iterator
INFO:root:Epoch[15] Time cost=2.185
INFO:root:Epoch[15] Validation-accuracy=0.972400
INFO:root:Epoch[16] Batch [200] Speed: 27911.90 samples/sec     accuracy=0.991800
INFO:root:Epoch[16] Batch [400] Speed: 28186.87 samples/sec     accuracy=0.991300
INFO:root:Epoch[16] Batch [600] Speed: 29020.00 samples/sec     accuracy=0.992050
INFO:root:Epoch[16] Resetting Data Iterator
INFO:root:Epoch[16] Time cost=2.120
INFO:root:Epoch[16] Validation-accuracy=0.972900
INFO:root:Epoch[17] Batch [200] Speed: 24626.20 samples/sec     accuracy=0.992850
INFO:root:Epoch[17] Batch [400] Speed: 25938.36 samples/sec     accuracy=0.992200
INFO:root:Epoch[17] Batch [600] Speed: 25045.08 samples/sec     accuracy=0.992900
INFO:root:Epoch[17] Resetting Data Iterator
INFO:root:Epoch[17] Time cost=2.386
INFO:root:Epoch[17] Validation-accuracy=0.973300
INFO:root:Epoch[18] Batch [200] Speed: 25552.00 samples/sec     accuracy=0.994200
INFO:root:Epoch[18] Batch [400] Speed: 25847.90 samples/sec     accuracy=0.993250
INFO:root:Epoch[18] Batch [600] Speed: 26043.37 samples/sec     accuracy=0.993550
INFO:root:Epoch[18] Resetting Data Iterator
INFO:root:Epoch[18] Time cost=2.330
INFO:root:Epoch[18] Validation-accuracy=0.974400
INFO:root:Epoch[19] Batch [200] Speed: 26778.91 samples/sec     accuracy=0.994900
INFO:root:Epoch[19] Batch [400] Speed: 25334.35 samples/sec     accuracy=0.994250
INFO:root:Epoch[19] Batch [600] Speed: 26160.49 samples/sec     accuracy=0.994500
INFO:root:Epoch[19] Resetting Data Iterator
INFO:root:Epoch[19] Time cost=2.305
INFO:root:Epoch[19] Validation-accuracy=0.975000



