[MXNet]Lecture02卷积神经网络的gluon实现

gluon实现的准确率要高一点,训练较快。

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import ndarray as nd
from mxnet import autograd

ctx=mx.gpu()
def transform(data,label):
	return data.astype('float32')/255,label.astype('float32')

train_mnist=gluon.data.vision.FashionMNIST(train=True,transform=transform)
test_mnist=gluon.data.vision.FashionMNIST(train=False,transform=transform)

batch_size=256
train_data=gluon.data.DataLoader(train_mnist,batch_size,shuffle=True)
test_data=gluon.data.DataLoader(test_mnist,batch_size,shuffle=False)

net=nn.Sequential()
with net.name_scope():#主要的区别在这部分,net的构建
	net.add(
		nn.Conv2D(channels=20,kernel_size=5,activation='relu'),
		nn.MaxPool2D(pool_size=2,strides=2),
		nn.Conv2D(channels=50,kernel_size=3,activation='relu'),
		nn.MaxPool2D(pool_size=2,strides=2),
		nn.Flatten(),
		nn.Dense(128,activation='relu'),
		nn.Dense(10)
	)

net.initialize(ctx=ctx)

softmax_cross_entropy_loss=gluon.loss.SoftmaxCrossEntropyLoss()
trainer=gluon.Trainer(net.collect_params(),'sgd',{"learning_rate":0.5})

def accuracy(output,label):
	return nd.mean(output.argmax(axis=1)==label).asscalar()

def evaluate_accuracy(test_data,net):
	acc=0.
	for data, label in test_data:
		data=data.reshape((-1,1,28,28))
		data=data.as_in_context(ctx)
		label=label.as_in_context(ctx)

		output=net(data)
		label=label.as_in_context(ctx)
		acc+=accuracy(output,label)
	return acc/len(test_data)

for epoch in range(5):
	train_loss=0.
	train_acc=0.
	for data,label in train_data:
		data=data.reshape((-1,1,28,28))#手动reshape
		data=data.as_in_context(ctx)#手动复制
		label=label.as_in_context(ctx)
		with autograd.record():
			output=net(data)
			loss=softmax_cross_entropy_loss(output,label)
		loss.backward()
		trainer.step(batch_size)
		
		train_loss+=nd.mean(loss).asscalar()
		train_acc+=accuracy(output,label)
	
	test_acc=evaluate_accuracy(test_data,net)
	print("Epoch:%d, loss: %f, Train_acc: %f, Test_acc: %f"%(epoch,train_loss/len(train_data),train_acc/len(train_data),test_acc))


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值