这是来自MXNet官网里面的一个例子,利用module包[1]来构建一个多层感知机,并在UCI letter recognition[2]中进行训练。利用module包来训练网络,可以采用两种接口:中层接口和高层接口。高层接口可以看做是对中层接口的一种封装。
利用module进行网络训练一般包括以下四个步骤:
- 载入数据
一般将数据载入到内存中,可以是全体训练集数据,也可以是部分数据集。 - 定义网络
利用symbol接口定义网络模型。 - 创建module模块
指定训练的设备,可以是gpu或是cpu,或是多卡gpu。 - 调用训练接口
定义好优化方法、评价指标、模型存储方法等。
载入数据
载入数据的目的是将训练集和验证集载入到内存中。因为,我们训练网络的数据集比较大,所以我们一般使用一个迭代器,这样每次只用载入batch大小的数据。
下面是下载该数据集,并将数据集按照8:2的比例分为训练集和测试集。
import logging
logging.getLogger().setLevel(logging.INFO)
import mxnet as mx
import numpy as np
mx.random.seed(1234)
fname = mx.test_utils.download('https://s3.us-east-2.amazonaws.com/mxnet-public/letter_recognition/letter-recognition.data')
data = np.genfromtxt(fname, delimiter=',)[:,1:] #data.shape = (10000L, 16L)
label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])
batch_size = 32
ntrain = int(data.shape[0]*0.8)
train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain],batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:],batch_size)
定义网络
利用symbol接口定义网络模型,包括网络的输入和标签。
net = mx.sym.Variable("data")
net = mx.sym.FullyConnected(net, name="fc1", num_hidden=64)
net = mx.sym.Activation(net, name="relu1", act_type="relu")
net = mx.sym.FullyConnected(net, name="fc2", num_hidden=26)
net = mx.sym.SoftmaxOutput(net, name="softmax")
mx.viz.plot_network(net) # 显示网路模型
创建module模块
mod = mx.mod.Module(symbol=net, context=mx.cpu(), # 指定cpu或是gpu
data_names=['data'],
label_names=['softmax_label'])
中层接口
采用中层接口的好处是便于debug。因为,中层接口需要明确写出forward和backward。如下代码:
mod.bind(data_shapes=train_iter.provide_data,
label_shapes=train_iter.provide_label)
mod.init_params(initializer=mx.init.Uniform(scale=.1))
mod.init_optimizer(optimizer='sgd',optimizer_params=(('learning_rate',0.1),))
metric=mx.metric.create('acc')
for epoch in range(5):
train_iter.reset()
metric.reset()
for batch in train_iter:
mod.forward(batch, is_train=True)
mod.update_metric(metric, batch.label)
mod.backward()
mod.update()
print('Epoch %d, Training %s' % (epoch, metric.get()))
采用高层接口
train_iter.reset()
mod = mx.mod.Module(symbol=net, context=mx.cpu(),
data_names=['data'], label_name=['softmax_label'])
mod.fit(train_iter, eval_data=val_iter,
optimizer='sgd', optimizer_params={'learning_rate':0.1},
eval_metric='acc', num_epoch=60)
可以看出,module为了方便,将训练、预测和评价等操作都进行了封装。与中层接口不同,需要一步一步来操作,module包直接调用了fit函数接口就完成相同的操作。
预测和指标评价
可以采用两种方式:
# 方式一:直接获得预测结果
y = mod.predict(val_iter)
# 方式二:直接获得指标结果
score = mod.score(val_iter, ['acc'])
print("Accuracy score is %f" % (score[0][1]))
使用第二种方式,我们可以直接获得指标结果,而指标我们也可以根据实际情况进行替换。如top_k_acc,F1,RMSE等等。
保存和载入模型
使用checkpoint callback来控制每一轮迭代后是否需要自动保存模型,如下:
model_prefix = "mx_mlp"
checkpoint = mx.callback.do_checkpoint(model_prefix)
mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)
为了载入模型,我们需要调用load_checkpoint函数。这个函数载入symbol以及对应的参数,然后我们将这些模型和参数来初始化module。如下:
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson()==net.tojson()
mod.set_params(arg_params, aux_params)
如果我们想从某个节点训练模型,那么我们可以直接调用fit函数。在fit函数中,我们可以直接载入这些参数,而不是初始化。另外,我们可以设置begin_epoch参数,让模型知道我们是从之前某个节点开始训练。
mod=mx.mod.Module(symbol=sym)
mod.fit(train_iter, num_epoch=21,
arg_params=arg_params,
aux_params=aux_params,
begin_epoch=3)
参考
[1] https://mxnet.incubator.apache.org/tutorials/basic/module.html
[2] https://archive.ics.uci.edu/ml/datasets/letter+recognition