MXNet tutorial——Train the neural network

数据集:FashionMNIST

  1. 导入类库
from mxnet import nd, gluon, init, autograd
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
import time
  1. 准备数据
// 训练集
mnist_train = datasets.FashionMNIST(train=True)
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31)])
mnist_train = mnist_train.transform_first(transformer)
// 验证集
mnist_valid = gluon.data.vision.FashionMNIST(train=False)
valid_data = gluon.data.DataLoader(
    mnist_valid.transform_first(transformer),
    batch_size=batch_size, num_workers=4)
  1. 定义模型
// 定义模型
net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        nn.Dense(10))
// 参数初始化
net.initialize(init=init.Xavier())
// 定义损失函数(交叉熵损失函数)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
// 定义优化器(sgd优化)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
  1. 训练模型
def acc(output, label):
    return (output.argmax(axis=1) ==
            label.astype('float32')).mean().asscalar()
  
for epoch in range(10):
    train_loss, train_acc, valid_acc = 0., 0., 0.
    tic = time.time()
    for data, label in train_data:
        # 前向传播+反向传播
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        # 更新参数
        trainer.step(batch_size)
        # 计算训练精度
        train_loss += loss.mean().asscalar()
        train_acc += acc(output, label)
    # 计算验证精度
    for data, label in valid_data:
        valid_acc += acc(net(data), label)
    print("Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec" % (
            epoch, train_loss/len(train_data), train_acc/len(train_data),
            valid_acc/len(valid_data), time.time()-tic))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值