MXNET深度学习框架-12-使用gluon实现LeNet-5

上一章从0开始实现了一个简单的CNN,但是有点麻烦,接下来使用gluon中的api来实现经典的LeNet-5:
代码如下:

import mxnet.ndarray as nd
import mxnet.autograd as ag
import mxnet.gluon as gn
import mxnet as mx
import matplotlib.pyplot as plt
import sys
from mxnet import init
import os
# 继续使用FashionMNIST
mnist_train = gn.data.vision.FashionMNIST(train=True)
mnist_test = gn.data.vision.FashionMNIST(train=False)


def transform(data, label):
    return data.astype("float32") / 255, label.astype("float32")  # 样本归一化

'''----数据读取----'''
batch_size = 256

train_data = gn.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_data = gn.data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)
ctx = mx.gpu(0)

# 定义模型
def get_net():
    net = gn.nn.Sequential()
    net.add(gn.nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
        gn.nn.MaxPool2D(pool_size=2, strides=2),
        gn.nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
        gn.nn.MaxPool2D(pool_size=2, strides=2),
        gn.nn.Dense(120, activation='sigmoid'),
        gn.nn.Dense(84, activation='sigmoid'),
        gn.nn.Dense(10))
    net.initialize(ctx=ctx, init=init.Xavier())  # init.Xavier()随机初始化参数
    return net
net=get_net()
# 定义准确率
def accuracy(output,label):
    return nd.mean(output.argmax(axis=1)==label).asscalar()

def evaluate_accuracy(data_iter,net):# 定义测试集准确率
    acc=0
    for data,label in data_iter:
        data, label = data.as_in_context(ctx), label.as_in_context(ctx)
        data,label=transform(data,label)
        output=net(data.reshape(-1,1,28,28))
        acc+=accuracy(output,label)
    return acc/len(data_iter)

# softmax和交叉熵分开的话数值可能会不稳定
cross_loss=gn.loss.SoftmaxCrossEntropyLoss()
# 优化
train_step=gn.Trainer(net.collect_params(),'sgd',{"learning_rate":0.9})
'''---训练---'''
epochs=50
train_avg_acc,train_avg_ls,test_avg_acc=[],[],[]
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    for image,y in train_data:
        image, y = image.as_in_context(ctx), y.as_in_context(ctx)
        image, y = transform(image, y)  # 类型转换,数据归一化
        image=image.reshape(-1,1,28,28)
        with ag.record():
            output=net(image)
            loss=cross_loss(output,y)
        loss.backward()
        train_step.step(batch_size)
        train_loss += nd.mean(loss).asscalar()
        train_acc += accuracy(output, y)
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d, Loss:%f, Train acc:%f, Test acc:%f"
          % (epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
    train_avg_acc.append(train_acc / len(train_data))
    train_avg_ls.append(train_loss / len(train_data))
    test_avg_acc.append(test_acc)
plt.ylim(0, 1) #设置y轴区间
plt.grid() #网格线
plt.plot(train_avg_acc)
plt.plot(train_avg_ls)
plt.plot(test_avg_acc,linestyle=':') # 虚线
plt.legend(['train acc','train loss','test acc'])
plt.show()

运行结果:
在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
MXNet是一个旨在提高效率和灵活性的深度学习框架。它是亚马逊选择的深度学习库之一,也被认为是最优秀的库之一。MXNet拥有类似于Theano和TensorFlow的数据流图,可以在多个GPU上进行配置,并提供了类似于Lasagne和Blocks的高级模型构建块。此外,MXNet还提供了对多种编程语言的支持,包括Python、R、Julia、C++、Scala、Matlab和Javascript。MXNet的目标是加速大规模深度神经网络的开发和部署,它提供了设备放置、多GPU训练、自动区分和优化的预定义图层等功能,以帮助开发人员充分利用GPU和云计算的能力。MXNet还具有计算和内存效率高的特点,可以在各种异构系统上运行,从移动设备到分布式GPU集群。\[1\]\[2\]\[3\] #### 引用[.reference_title] - *1* *3* [DL框架MXNet深度学习框架MXNet 的简介、安装、使用方法、应用案例之详细攻略](https://blog.csdn.net/qq_41185868/article/details/79153500)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [深度学习框架MxNet】的安装](https://blog.csdn.net/ctu_sue/article/details/127426528)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值