Training a CNN in MXNet with a pre-downloaded MNIST dataset

  This time we first download the MNIST dataset from http://yann.lecun.com/exdb/mnist/, then wrap the data with mx.io.MNISTIter and use it to train a CNN (the classic LeNet architecture).
The code is as follows:

import mxnet as mx
import os
import logging
logging.getLogger().setLevel(logging.DEBUG)

# Training data
# logging.basicConfig(filename=os.path.join(os.getcwd(), 'log.txt'), level=logging.DEBUG)  # save the log to log.txt
batch_size = 100
path = 'E:/python file/data_set/mnist/'  # location of the data files (forward slashes avoid backslash-escape issues on Windows)
train_iter = mx.io.MNISTIter(image=path+'train-images.idx3-ubyte',
                             label=path+'train-labels.idx1-ubyte',
                             batch_size=batch_size, shuffle=True)
val_iter = mx.io.MNISTIter(image=path+'t10k-images.idx3-ubyte',
                           label=path+'t10k-labels.idx1-ubyte',
                           batch_size=batch_size)
data = mx.sym.var('data')
# first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fullc layer
flatten = mx.sym.flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fullc
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

# create a trainable module on the GPU; CPU is also an option (mx.cpu())
lenet_model = mx.mod.Module(symbol=lenet, context=mx.gpu(0))
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate': 0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                num_epoch=10)
acc = mx.metric.Accuracy()
lenet_model.score(val_iter, acc)
print(acc)
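Before training, it can help to sanity-check the layer output shapes by hand. A minimal sketch in pure Python (no MXNet needed; the kernel and pooling sizes are taken from the network above, assuming 28×28 single-channel MNIST inputs):

```python
# Walk a 28x28 MNIST image through the LeNet layers defined above
# and compute the expected spatial size after each stage.

def conv_out(size, kernel, stride=1, pad=0):
    """Output side length of a convolution or pooling layer (floor division)."""
    return (size + 2 * pad - kernel) // stride + 1

h = 28                          # MNIST images are 28x28, 1 channel
h = conv_out(h, 5)              # conv1, 5x5 kernel -> 24x24, 20 filters
h = conv_out(h, 2, stride=2)    # pool1, 2x2 / stride 2 -> 12x12
h = conv_out(h, 5)              # conv2, 5x5 kernel -> 8x8, 50 filters
h = conv_out(h, 2, stride=2)    # pool2 -> 4x4
flat = 50 * h * h               # flatten: 50 channels * 4 * 4 = 800
print(flat)                     # 800 inputs feed fc1 (500 hidden), then fc2 (10 classes)
```

So the flatten layer hands 800 features to the first fully connected layer.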

The result is as follows:

INFO:root:Epoch[0] Batch [500]  Speed: 5144.38 samples/sec  Train-accuracy=0.112475
INFO:root:Epoch[0] Train-accuracy=0.109596
INFO:root:Epoch[0] Time cost=11.737
INFO:root:Epoch[0] Validation-accuracy=0.112367
INFO:root:Epoch[1] Batch [500]  Speed: 5115.53 samples/sec  Train-accuracy=0.675250
INFO:root:Epoch[1] Train-accuracy=0.937879
INFO:root:Epoch[1] Time cost=11.742
INFO:root:Epoch[1] Validation-accuracy=0.942183
INFO:root:Epoch[2] Batch [500]  Speed: 5108.45 samples/sec  Train-accuracy=0.955669
INFO:root:Epoch[2] Train-accuracy=0.968889
INFO:root:Epoch[2] Time cost=11.723
INFO:root:Epoch[2] Validation-accuracy=0.972250
INFO:root:Epoch[3] Batch [500]  Speed: 5110.60 samples/sec  Train-accuracy=0.974232
INFO:root:Epoch[3] Train-accuracy=0.979495
INFO:root:Epoch[3] Time cost=11.719
INFO:root:Epoch[3] Validation-accuracy=0.980050
INFO:root:Epoch[4] Batch [500]  Speed: 5116.41 samples/sec  Train-accuracy=0.980998
INFO:root:Epoch[4] Train-accuracy=0.984848
INFO:root:Epoch[4] Time cost=11.720
INFO:root:Epoch[4] Validation-accuracy=0.984467
INFO:root:Epoch[5] Batch [500]  Speed: 5117.45 samples/sec  Train-accuracy=0.984671
INFO:root:Epoch[5] Train-accuracy=0.988081
INFO:root:Epoch[5] Time cost=11.742
INFO:root:Epoch[5] Validation-accuracy=0.987033
INFO:root:Epoch[6] Batch [500]  Speed: 5112.38 samples/sec  Train-accuracy=0.987006
INFO:root:Epoch[6] Train-accuracy=0.990505
INFO:root:Epoch[6] Time cost=11.729
INFO:root:Epoch[6] Validation-accuracy=0.988817
INFO:root:Epoch[7] Batch [500]  Speed: 5109.20 samples/sec  Train-accuracy=0.989022
INFO:root:Epoch[7] Train-accuracy=0.991515
INFO:root:Epoch[7] Time cost=11.776
INFO:root:Epoch[7] Validation-accuracy=0.990300
INFO:root:Epoch[8] Batch [500]  Speed: 5139.08 samples/sec  Train-accuracy=0.990220
INFO:root:Epoch[8] Train-accuracy=0.992323
INFO:root:Epoch[8] Time cost=11.679
INFO:root:Epoch[8] Validation-accuracy=0.991267
INFO:root:Epoch[9] Batch [500]  Speed: 4871.87 samples/sec  Train-accuracy=0.991297
INFO:root:Epoch[9] Train-accuracy=0.992727
INFO:root:Epoch[9] Time cost=12.316
INFO:root:Epoch[9] Validation-accuracy=0.992250
EvalMetric: {'accuracy': 0.99224999999999997}
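The idx-ubyte files themselves are easy to parse by hand if you ever want to inspect the raw data outside MXNet. A minimal sketch of a reader for the idx format (big-endian header, as documented on the download page; the helper name `parse_idx` is my own):

```python
import struct
import numpy as np

def parse_idx(data):
    """Parse raw bytes in the MNIST idx format into a NumPy array.

    Header layout: two zero bytes, a type code (0x08 = unsigned byte),
    the number of dimensions, then one big-endian uint32 per dimension.
    """
    zeros, dtype_code, ndim = struct.unpack('>HBB', data[:4])
    assert zeros == 0 and dtype_code == 0x08  # unsigned-byte data expected
    shape = struct.unpack('>' + 'I' * ndim, data[4:4 + 4 * ndim])
    return np.frombuffer(data, dtype=np.uint8, offset=4 + 4 * ndim).reshape(shape)

# Usage with the same paths as the training script above:
# with open(path + 'train-images.idx3-ubyte', 'rb') as f:
#     images = parse_idx(f.read())   # expected shape: (60000, 28, 28)
```

This is handy for verifying that the downloaded files are intact before handing them to mx.io.MNISTIter.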