# Training pipeline
import mxnet as mx
from mxnet.gluon import loss as gloss, nn
import mxnet.gluon as gluon
from mxnet import autograd
import mxnet.ndarray as nd
import numpy as np
import mxnet.metric
class LeNet(gluon.nn.HybridBlock):
    """LeNet-style convolutional classifier built from Gluon layers.

    Layer stack: conv(20, 5x5, relu) -> maxpool(2, stride 2) ->
    conv(50, 5x5, relu) -> maxpool(2, stride 2) -> flatten ->
    dense(feature_size) -> dense(classes).

    Parameters
    ----------
    classes : int
        Number of units in the final dense layer.
    feature_size : int
        Number of units in the intermediate dense layer.
    **kwargs
        Forwarded unchanged to ``HybridBlock.__init__``.
    """

    def __init__(self, classes=10, feature_size=120, **kwargs):
        super(LeNet, self).__init__(**kwargs)
        with self.name_scope():
            self.conv1 = nn.Conv2D(channels=20, kernel_size=5, activation='relu')
            self.conv2 = nn.Conv2D(channels=50, kernel_size=5, activation='relu')
            # One pooling layer object is reused after both convolutions.
            self.maxpool = nn.MaxPool2D(pool_size=2, strides=2)
            self.flat = nn.Flatten()
            # NOTE(review): dense1 has no activation; presumably intentional
            # so its output can serve as a linear feature vector -- confirm.
            self.dense1 = nn.Dense(feature_size)
            self.dense2 = nn.Dense(classes)

    def hybrid_forward(self, F, x, *args, **kwargs):
        """Run the forward pass and return the output of the final dense layer.

        ``F``, ``*args`` and ``**kwargs`` are accepted for the HybridBlock
        calling convention but are not used here.
        """
        x = self.maxpool(self.conv1(x))
        x = self.maxpool(self.conv2(x))
        # Apply the Flatten layer declared in __init__ instead of relying on
        # Dense's implicit flattening; it was previously created but unused.
        x = self.flat(x)
        ft = self.dense1(x)
        output = self.dense2(ft)
        return output
def transformer(data, label):
    """Convert one (data, label) sample into the form fed to the network.

    ``data`` is cast to float32, its axes are permuted with (2, 0, 1)
    (presumably HWC -> CHW; confirm against the dataset), converted to a
    NumPy array and divided by 255. ``label`` is cast to int32.

    Returns a ``(data, label)`` tuple.
    """
    as_float = data.astype(np.float32)
    reordered = nd.transpose(as_float, (2, 0, 1))
    scaled = reordered.asnumpy() / 255
    return scaled, label.astype(np.int32)
def try_gpu():
    """Return ``mx.gpu()`` when an array can be created on it, else ``mx.cpu()``.

    The GPU context is probed by allocating a one-element array on it. Only
    MXNet errors trigger the CPU fallback, so ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed as they were by a bare ``except``.
    """
    try:
        ctx = mx.gpu()
        # Throwaway allocation: its only purpose is to fail if the GPU
        # context cannot be used.
        _ = nd.zeros((1,), ctx=ctx)
    except mx.base.MXNetError:
        ctx = mx.cpu()
    return ctx
if __name__ == '__main__':
    # Build the model and request hybridization before parameters exist.
    net = LeNet()
    net.hybridize()

    # MNIST loaders; every sample goes through `transformer`.
    train_set = gluon.data.vision.MNIST('./data', train=True, transform=transformer)
    train_data = gluon.data.DataLoader(
        train_set, batch_size=64, shuffle=True, last_batch='discard')
    val_set = gluon.data.vision.MNIST('./data', train=False, transform=transformer)
    # NOTE: val_data is constructed here but not consumed anywhere in this block.
    val_data = gluon.data.DataLoader(val_set, batch_size=100, shuffle=False)

    # Choose the device, then initialize the parameters on it.
    ctx = try_gpu()
    print(ctx)
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

    sgd_settings = {'learning_rate': 0.01, 'wd': 5e-4}
    trainer = gluon.Trainer(
        net.collect_params(), optimizer='sgd', optimizer_params=sgd_settings)
    metric = mx.metric.Accuracy()
    criterion = gluon.loss.SoftmaxCrossEntropyLoss()

    num_epochs = 10
    for epoch in range(num_epochs):
        # Accuracy is accumulated per epoch, so clear it at the start of each.
        metric.reset()
        for step, (data, label) in enumerate(train_data):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)

            # Only the forward pass and loss are recorded for autograd.
            with autograd.record():
                output = net(data)
                batch_loss = criterion(output, label)
            batch_loss.backward()

            # The trainer step is given the number of samples in this batch.
            trainer.step(data.shape[0])
            metric.update([label], [output])

            # Report running accuracy every 100 batches, skipping batch 0.
            if step > 0 and step % 100 == 0:
                name, acc = metric.get()
                print('[Epoch %d Batch %d] Training: %s=%f'%(epoch, step, name, acc))

        # End-of-epoch summary of the accumulated training accuracy.
        name, acc = metric.get()
        print('[Epoch %d] Training: %s=%f'%(epoch, name, acc))