利用MXNet的Module接口构建一个CNN模型

本章讲述利用MXNet构建一个简单CNN模型,并在MNIST数据集[1]上进行训练和测试。

整体结构依旧是:

  • 载入数据,并放到数据迭代器中
  • 定义网络模型
  • 定义module,指定训练位置
  • 调用fit接口,进行训练
  • 进行测试

代码如下:

#encoding:utf-8

import logging       # 对于输出每一轮的训练信息很重要
logging.getLogger().setLevel(logging.INFO)

import os
import mxnet as mx
from mxnet import nd

# 准备数据,并放到NDArrayIter迭代器中
mnist = mx.test_utils.get_mnist()

mx.random.seed(42)

batch_size = 100
train_iter = mx.io.NDArrayIter(mnist["train_data"], mnist["train_label"], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)

# 定义网络
data = mx.sym.var('data')
conv1 = mx.sym.Convolution(data=data, kernel=(3,3), num_filter=20)
relu1 = mx.sym.Activation(data=conv1, act_type="relu")
pool1 = mx.sym.Pooling(data=relu1, pool_type="max", kernel=(2,2), stride=(2,2))

conv2 = mx.sym.Convolution(data=pool1, kernel=(3,3), num_filter=20)
relu2 = mx.sym.Activation(data=conv2, act_type="relu")
pool2 = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2,2), stride=(2,2))

flatten = mx.sym.flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
relu3 = mx.sym.Activation(data=fc1, act_type="relu")

fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=10)

cnn_symbol = mx.sym.SoftmaxOutput(data=fc2, name="softmax")

# 定义module
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
cnn_model = mx.mod.Module(symbol=cnn_symbol, context=ctx)

# 训练
cnn_model.fit(train_iter, eval_data=val_iter, optimizer='sgd', optimizer_params={'learning_rate':0.1},
              batch_end_callback = mx.callback.Speedometer(batch_size, 100),   # 100个batch以后输出一次训练信息
              eval_metric='acc', 
              num_epoch=10)  # 训练10个epochs,也就是训练集数据走10遍

# 测试
test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = cnn_model.predict(test_iter)   # 测试1

test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
acc = mx.metric.Accuracy()
cnn_model.score(test_iter, acc)   # 测试2
print(acc)
assert acc.get()[1] > 0.98, "Achieved accuracy (%f) is lower than expected (0.98)" % acc.get()[1]

接下来需要探索的问题是:

  1. 怎么从原始图片载入到内存中,如果数据量比较大应该如何应对?如:数据量有100G。
  2. 数据增强操作应该如何进行?如果我们采用gluon接口,那么gluon接口中就有gluon.data.vision.transforms包进行数据增强,但是采用symbol接口应该怎样增强?
  3. 怎样修改损失函数层?我们怎么去定制化损失函数?
参考

[1] http://yann.lecun.com/exdb/mnist/
[2] https://mxnet.incubator.apache.org/tutorials/python/mnist.html

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值