本章讲述利用MXNet构建一个简单CNN模型,并在MNIST数据集[1]上进行训练和测试。
整体结构依旧是:
- 载入数据,并放到数据迭代器中
- 定义网络模型
- 定义module,指定训练位置
- 调用fit接口,进行训练
- 进行测试
代码如下:
#encoding:utf-8
import logging # 对于输出每一轮的训练信息很重要
logging.getLogger().setLevel(logging.INFO)
import os
import mxnet as mx
from mxnet import nd
# 准备数据,并放到NDArrayIter迭代器中
mnist = mx.test_utils.get_mnist()
mx.random.seed(42)
batch_size = 100
train_iter = mx.io.NDArrayIter(mnist["train_data"], mnist["train_label"], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
# 定义网络
data = mx.sym.var('data')
conv1 = mx.sym.Convolution(data=data, kernel=(3,3), num_filter=20)
relu1 = mx.sym.Activation(data=conv1, act_type="relu")
pool1 = mx.sym.Pooling(data=relu1, pool_type="max", kernel=(2,2), stride=(2,2))
conv2 = mx.sym.Convolution(data=pool1, kernel=(3,3), num_filter=20)
relu2 = mx.sym.Activation(data=conv2, act_type="relu")
pool2 = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2,2), stride=(2,2))
flatten = mx.sym.flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
relu3 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=10)
cnn_symbol = mx.sym.SoftmaxOutput(data=fc2, name="softmax")
# 定义module
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
cnn_model = mx.mod.Module(symbol=cnn_symbol, context=ctx)
# 训练
cnn_model.fit(train_iter, eval_data=val_iter, optimizer='sgd', optimizer_params={'learning_rate':0.1},
batch_end_callback = mx.callback.Speedometer(batch_size, 100), # 100个batch以后输出一次训练信息
eval_metric='acc',
num_epoch=10) # 训练10个epochs,也就是训练集数据走10遍
# 测试
test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = cnn_model.predict(test_iter) # 测试1
test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
acc = mx.metric.Accuracy()
cnn_model.score(test_iter, acc) # 测试2
print(acc)
assert acc.get()[1] > 0.98, "Achieved accuracy (%f) is lower than expected (0.98)" % acc.get()[1]
接下来需要探索的问题是:
- 怎么从原始图片载入到内存中,如果数据量比较大应该如何应对?如:数据量有100G。
- 数据增强操作应该如何进行?如果我们采用gluon接口,那么gluon接口中就有gluon.data.vision.transforms包进行数据增强,但是采用symbol接口应该怎样增强?
- 怎样修改损失函数层?我们怎么去定制化损失函数?
参考
[1] http://yann.lecun.com/exdb/mnist/
[2] https://mxnet.incubator.apache.org/tutorials/python/mnist.html