文章标题

import mxnet as mx

def mlp():
    data=mx.sym.Variable('data')
    fc1=mx.sym.FullyConnected(data, name="fc1", num_hidden=512)
    act1=mx.sym.Activation(fc1, name="relu1", act_type="relu")
    fc2=mx.sym.FullyConnected(act1, name="fc2", num_hidden=512)
    act2=mx.sym.Activation(fc2, name="relu2", act_type="relu")
    fc3=mx.sym.FullyConnected(act2, name="fc3", num_hidden=10)
    mlp=mx.sym.SoftmaxOutput(fc3, name="softmax")
    return mlp 

if __name__=="__main__":
    num_epoch=3
    batch_size=100
    train_dataiter=mx.io.CSVIter(data_csv="mnist.train", data_shape=(28, 28), label_csv="label.train", label_shape=(1,), batch_size=batch_size)
    val_dataiter=mx.io.CSVIter(data_csv="mnist.val", data_shape=(28, 28), label_csv="label.val", label_shape=(1,), batch_size=batch_size)
    mlp=mlp()
    ###config model_args
    model_args = dict()
    ##the first parameter is the number of the batch_num
    model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(500, 0.9)
    model=mx.model.FeedForward(
            ctx=mx.gpu(0),
            symbol=mlp,
            num_epoch=5,
            learning_rate=0.01,
            momentum=0.9,
            wd=0.01,
            **model_args)

    ####cofig log file
    import logging
    LOG_FILE='mnist.log'
    logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG)

    model.fit(
            X=train_dataiter,
            eval_data=val_dataiter,
            batch_end_callback=mx.callback.Speedometer(batch_size, 50),)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值