csviter mxnet

import mxnet as mx

def mlp():
    data=mx.sym.Variable('data')
    fc1=mx.sym.FullyConnected(data, name="fc1", num_hidden=512)
    act1=mx.sym.Activation(fc1, name="relu1", act_type="relu")
    fc2=mx.sym.FullyConnected(act1, name="fc2", num_hidden=512)
    act2=mx.sym.Activation(fc2, name="relu2", act_type="relu")
    fc3=mx.sym.FullyConnected(act2, name="fc3", num_hidden=10)
    mlp=mx.sym.SoftmaxOutput(fc3, name="softmax")
    return mlp 

if __name__=="__main__":
    num_epoch=3
    batch_size=100
    train_dataiter=mx.io.CSVIter(data_csv="mnist.train", data_shape=(28, 28), label_csv="label.train", label_shape=(1,), batch_size=batch_size)
    val_dataiter=mx.io.CSVIter(data_csv="mnist.val", data_shape=(28, 28), label_csv="label.val", label_shape=(1,), batch_size=batch_size)
    mlp=mlp()
    mod=mx.mod.Module(mlp)
    #bind and init_params
    mod.bind(data_shapes=train_dataiter.provide_data, label_shapes=train_dataiter.provide_label)
    mod.init_params()

    mod.init_optimizer(optimizer_params={'learning_rate':0.01, 'momentum': 0.9})
    metric = mx.metric.create('acc')

    for i_epoch in range(num_epoch):
        for i_iter, batch in enumerate(train_dataiter):
            mod.forward(batch)
            mod.update_metric(metric, batch.label)

            mod.backward()
            mod.update()

        for name, val in metric.get_name_value():
            print('epoch %03d: %s=%f' % (i_epoch, name, val))
        metric.reset()
        train_dataiter.reset()

the format of cvs data is separated by comma.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值