上一篇实现了 BN(批量归一化)相关算法,并从零开始搭建了一个 LeNet 来训练数据集,但从代码量来看有点多。本章介绍如何使用 MXNet 中的高阶 API——Gluon 来构建模型(包含 BN)。
import mxnet.ndarray as nd
import mxnet.autograd as ag
import mxnet.gluon as gn
import mxnet as mx
import matplotlib.pyplot as plt
import sys
from mxnet import init
import os
# Keep using FashionMNIST: build the training split first, then the test split.
mnist_train, mnist_test = (
    gn.data.vision.FashionMNIST(train=is_train) for is_train in (True, False)
)
def transform(data, label):
    """Convert a (data, label) batch to float32 and normalize the data.

    The data is divided by 255 (presumably uint8 pixels, so values land in
    [0, 1] — verify against the dataset). The label is only cast to float32.

    Returns:
        A ``(scaled_data, float_label)`` tuple.
    """
    scaled = data.astype("float32") / 255  # sample normalization
    float_label = label.astype("float32")
    return scaled, float_label
# ---- Data loading ----
batch_size = 256
# The loaders yield raw samples; transform() (float32 cast + /255 scaling) is applied
# per batch inside the training and evaluation loops.
train_data = gn.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_data = gn.data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)
# Use the first GPU when one is available; otherwise fall back to the CPU so the script
# still runs on machines without a GPU (mx.gpu(0) alone fails there).
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
# Model definition
def get_net():
    """Build and initialize a LeNet-style CNN with batch normalization.

    Layout:
    - Two conv stages: Conv2D(k=5) -> BatchNorm -> sigmoid -> MaxPool(2, stride 2),
      with 6 and then 16 channels.
    - Two hidden dense stages: Dense -> BatchNorm -> sigmoid, with 120 and then 84 units.
    - A final Dense(10) producing the class scores.

    NOTE(review): the first Dense receives the pooled 4-D feature maps; this
    presumably relies on Gluon's Dense flattening its input by default.

    Parameters are Xavier-initialized on the module-level ``ctx``.
    """
    model = gn.nn.Sequential()
    stages = []
    for n_channels in (6, 16):
        stages.extend([
            gn.nn.Conv2D(channels=n_channels, kernel_size=5),
            gn.nn.BatchNorm(),  # batch normalization after each conv
            gn.nn.Activation('sigmoid'),
            gn.nn.MaxPool2D(pool_size=2, strides=2),
        ])
    for n_units in (120, 84):
        stages.extend([
            gn.nn.Dense(n_units),
            gn.nn.BatchNorm(),
            gn.nn.Activation('sigmoid'),
        ])
    stages.append(gn.nn.Dense(10))
    model.add(*stages)
    # Xavier initialization for all parameters.
    model.initialize(ctx=ctx, init=init.Xavier())
    return model


net = get_net()
# Accuracy metric
def accuracy(output, label):
    """Return the fraction of samples whose arg-max score along axis 1 equals ``label``.

    Assumes ``output`` is (batch, num_classes) scores and ``label`` holds class
    indices (TODO confirm). The result is a Python scalar.
    """
    predicted = output.argmax(axis=1)
    hits = predicted == label
    return nd.mean(hits).asscalar()
def evaluate_accuracy(data_iter, net):
    """Return the classification accuracy of ``net`` over every sample in ``data_iter``.

    Each batch is moved to the module-level ``ctx``, normalized with
    ``transform`` and reshaped to (N, 1, 28, 28) before the forward pass.

    Batch accuracies are weighted by batch size. A plain mean of per-batch
    accuracies would over-weight a smaller final batch.

    Raises:
        ValueError: if ``data_iter`` yields no samples.
    """
    correct = 0.0
    total = 0
    for data, label in data_iter:
        data, label = data.as_in_context(ctx), label.as_in_context(ctx)
        data, label = transform(data, label)
        output = net(data.reshape(-1, 1, 28, 28))
        n = label.shape[0]
        correct += accuracy(output, label) * n
        total += n
    if total == 0:
        raise ValueError("evaluate_accuracy: data_iter yielded no samples")
    return correct / total
# Use the fused softmax + cross-entropy loss: computing them separately can be numerically unstable.
cross_loss = gn.loss.SoftmaxCrossEntropyLoss()
# Optimizer: plain SGD (learning rate 0.2) over all network parameters.
train_step = gn.Trainer(net.collect_params(), 'sgd', {"learning_rate": 0.2})
# ---- Training ----
epochs = 50
train_avg_acc, train_avg_ls, test_avg_acc = [], [], []
for epoch in range(epochs):
    train_loss = 0.0     # sum of per-sample losses this epoch
    train_correct = 0.0  # number of correctly classified samples this epoch
    n_samples = 0
    for image, y in train_data:
        image, y = image.as_in_context(ctx), y.as_in_context(ctx)
        image, y = transform(image, y)  # type conversion + data normalization
        image = image.reshape(-1, 1, 28, 28)
        n = image.shape[0]
        with ag.record():
            output = net(image)
            loss = cross_loss(output, y)
        loss.backward()
        # Normalize the gradient by the actual batch size: the last batch of an epoch
        # can be smaller than batch_size.
        train_step.step(n)
        train_loss += nd.sum(loss).asscalar()
        train_correct += accuracy(output, y) * n
        n_samples += n
    # Sample-weighted epoch averages (not a mean of per-batch means).
    epoch_loss = train_loss / n_samples
    epoch_acc = train_correct / n_samples
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d, Loss:%f, Train acc:%f, Test acc:%f"
          % (epoch, epoch_loss, epoch_acc, test_acc))
    train_avg_acc.append(epoch_acc)
    train_avg_ls.append(epoch_loss)
    test_avg_acc.append(test_acc)
# NOTE: the y-axis is fixed to [0, 1]; any epoch loss above 1 would be clipped from view.
plt.ylim(0, 1)
plt.grid()  # grid lines
plt.plot(train_avg_acc)
plt.plot(train_avg_ls)
plt.plot(test_avg_acc, linestyle=':')  # dotted line
plt.legend(['train acc', 'train loss', 'test acc'])
plt.show()
运行结果: