import mxnet as mx
def mlp(num_hidden=512, num_classes=10):
    """Build a three-layer perceptron symbol ending in a softmax classifier.

    The network is ``data -> FC(num_hidden) -> ReLU -> FC(num_hidden) -> ReLU
    -> FC(num_classes) -> SoftmaxOutput``.  Layer names (``fc1``, ``relu1``,
    ``fc2``, ``relu2``, ``fc3``, ``softmax``) are fixed, so the generated
    parameter names (e.g. ``fc1_weight``) are the same on every call.

    Args:
        num_hidden: Width of each of the two hidden fully-connected layers.
            Defaults to 512.
        num_classes: Number of units in the final fully-connected layer,
            i.e. the number of classes. Defaults to 10.

    Returns:
        An ``mx.sym.Symbol`` taking the input variable ``data``; the
        ``SoftmaxOutput`` named ``softmax`` adds its label input
        ``softmax_label``.
    """
    data = mx.sym.Variable('data')
    fc1 = mx.sym.FullyConnected(data, name="fc1", num_hidden=num_hidden)
    act1 = mx.sym.Activation(fc1, name="relu1", act_type="relu")
    fc2 = mx.sym.FullyConnected(act1, name="fc2", num_hidden=num_hidden)
    act2 = mx.sym.Activation(fc2, name="relu2", act_type="relu")
    fc3 = mx.sym.FullyConnected(act2, name="fc3", num_hidden=num_classes)
    # Use a distinct local name so the enclosing function name is not shadowed.
    out = mx.sym.SoftmaxOutput(fc3, name="softmax")
    return out
if __name__ == "__main__":
    num_epoch = 3
    batch_size = 100

    # CSVIter reshapes each row of the data file to data_shape and each row of
    # the label file to label_shape, then groups rows into batches.
    train_dataiter = mx.io.CSVIter(data_csv="mnist.train", data_shape=(28, 28),
                                   label_csv="label.train", label_shape=(1,),
                                   batch_size=batch_size)
    val_dataiter = mx.io.CSVIter(data_csv="mnist.val", data_shape=(28, 28),
                                 label_csv="label.val", label_shape=(1,),
                                 batch_size=batch_size)

    # Keep the symbol under its own name so the module-level `mlp` builder
    # function is not rebound to a Symbol.
    net = mlp()
    mod = mx.mod.Module(net)

    # Bind executors to the training input/label shapes, then initialize the
    # parameters and the optimizer (MXNet's default optimizer is SGD; here
    # with momentum).
    mod.bind(data_shapes=train_dataiter.provide_data,
             label_shapes=train_dataiter.provide_label)
    mod.init_params()
    mod.init_optimizer(optimizer_params={'learning_rate': 0.01, 'momentum': 0.9})

    metric = mx.metric.create('acc')
    for i_epoch in range(num_epoch):
        for batch in train_dataiter:
            # Explicit training mode: evaluation passes below run with
            # is_train=False on the same module.
            mod.forward(batch, is_train=True)
            # Accumulate training accuracy from this forward pass's outputs
            # before update() changes the parameters.
            mod.update_metric(metric, batch.label)
            mod.backward()
            mod.update()
        for name, val in metric.get_name_value():
            print('epoch %03d: %s=%f' % (i_epoch, name, val))
        metric.reset()
        train_dataiter.reset()

        # Evaluate on the held-out split. score() resets val_dataiter and uses
        # its own freshly created metric, so training statistics are untouched.
        for name, val in mod.score(val_dataiter, 'acc'):
            print('epoch %03d: validation %s=%f' % (i_epoch, name, val))
The CSV data files are comma-separated.