# The Gluon implementation reaches slightly higher accuracy and trains faster.
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import ndarray as nd
from mxnet import autograd
# Run all computation on the GPU; requires a CUDA-enabled MXNet build.
ctx=mx.gpu()
def transform(data, label):
    """Cast a sample to float32, scaling pixel values from [0, 255] into [0, 1]."""
    scaled = data.astype('float32') / 255
    return scaled, label.astype('float32')
# FashionMNIST splits with on-the-fly float32 / [0, 1] normalization.
fashion_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
fashion_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)

batch_size = 256
# Shuffle only the training split; evaluation order is fixed.
train_data = gluon.data.DataLoader(fashion_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(fashion_test, batch_size, shuffle=False)
# LeNet-style CNN: two conv/pool stages followed by two dense layers.
# Building the model via gluon's layer API is the main difference from a
# from-scratch implementation.
net = nn.Sequential()
with net.name_scope():
    net.add(nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
    net.add(nn.MaxPool2D(pool_size=2, strides=2))
    net.add(nn.Conv2D(channels=50, kernel_size=3, activation='relu'))
    net.add(nn.MaxPool2D(pool_size=2, strides=2))
    net.add(nn.Flatten())
    net.add(nn.Dense(128, activation='relu'))
    net.add(nn.Dense(10))  # one logit per FashionMNIST class
net.initialize(ctx=ctx)

softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {"learning_rate": 0.5})
def accuracy(output, label):
    """Return the fraction of rows in `output` whose argmax equals `label`,
    as a host-side Python float."""
    predictions = output.argmax(axis=1)
    return nd.mean(predictions == label).asscalar()
def evaluate_accuracy(test_data, net):
    """Average per-batch accuracy of `net` over all batches in `test_data`.

    Relies on the module-level `ctx` (device) and `accuracy` helper.
    Note: this averages batch accuracies, so a smaller final batch is
    weighted the same as full batches.
    """
    acc = 0.
    for data, label in test_data:
        # Add the channel axis ((batch, 28, 28) -> (batch, 1, 28, 28))
        # and copy the batch to the compute device.
        data = data.reshape((-1, 1, 28, 28)).as_in_context(ctx)
        label = label.as_in_context(ctx)  # fix: original copied label to ctx twice per batch
        output = net(data)
        acc += accuracy(output, label)
    return acc / len(test_data)
# Main training loop: 5 epochs of SGD over the training set, reporting
# mean loss, train accuracy and test accuracy after each epoch.
for epoch in range(5):
    cumulative_loss = 0.
    cumulative_acc = 0.
    for data, label in train_data:
        # Manually add the channel axis and copy the batch onto the GPU.
        data = data.reshape((-1, 1, 28, 28)).as_in_context(ctx)
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy_loss(output, label)
        loss.backward()
        trainer.step(batch_size)  # gradients are normalized by the batch size
        cumulative_loss += nd.mean(loss).asscalar()
        cumulative_acc += accuracy(output, label)
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch:%d, loss: %f, Train_acc: %f, Test_acc: %f"%(epoch,cumulative_loss/len(train_data),cumulative_acc/len(train_data),test_acc))