Pytorch学习系列-02-构建浅层神经网络实现手写数字识别(Mnist数据集)

  • Pytorch搭建简单神经网络完成手写数字识别,其中利用到的知识点如下:
    (1)用torch.nn包里面的函数搭建网络
    (2)Torchvision.transofrms来做数据预处理
    (3)DataLoader简单调用处理数据集
    (4)模型保存为pt文件与加载调用

下面看完整的代码:

"""
1.使用torch.nn包里面的搭建网络
2.Torchvision.transofrms来做数据处理
3.DataLoader简单调用处理数据集
4.模型保存为pt文件与加载调用
"""

import torch as t
from torch.utils.data import DataLoader
import torchvision as tv

#预处理数据方式
##ToTensor表示把灰度图像像素值从0-255转为0-1之间
##Normalize表示把输入的减去0.5,除以0.5
transform = tv.transforms.Compose([tv.transforms.ToTensor(),
                                   tv.transforms.Normalize((0.5,),(0.5,)),
                                   ])

#数据集(Mnist数据集,数字为0-9,大小为28*28的灰度图像)
##加载此数据集,这里pytorch已经写好调用接口,直接运行,没有此数据集的话,会自动下载
train_ts = tv.datasets.MNIST(root='./data',train=True,download=True,transform=transform)
test_ts = tv.datasets.MNIST(root='./data',train=False,download=True,transform=transform)
train_dl = DataLoader(train_ts,batch_size=32,shuffle=True,drop_last=False)
test_dl = DataLoader(test_ts,batch_size=64,shuffle=True,drop_last=False)

#网络结构
##输入层:784个神经单元,隐藏层:100个神经单元,输出层:10个神经单元
model = t.nn.Sequential(
    t.nn.Linear(784,100),
    t.nn.ReLU(),
    t.nn.Linear(100,10),
    t.nn.LogSoftmax(dim=1)
)


#定义损失函数与优化函数
loss_fn = t.nn.NLLLoss(reduction='mean')
optimizer = t.optim.Adam(model.parameters(),lr=1e-3)

#开启训练
for s in range(5):
    print ('run in step :{}'.format(s))
    for i,(x_train,y_train) in enumerate(train_dl):
        x_train = x_train.view(x_train.shape[0],-1)
        y_pred = model(x_train)
        train_loss = loss_fn(y_pred,y_train)
        if (i+1)%100 == 0:
            print (i+1,train_loss.item())
        model.zero_grad()
        train_loss.backward()
        optimizer.step()

#测试模型准确性
total = 0
correct_count = 0
for test_images,test_labels in test_dl:
    for i in range(len(test_labels)):
        image = test_images[i].view(1,784)
        with t.no_grad():
            pred_labels = model(image)
        plabels = t.exp(pred_labels)
        probs = list(plabels.numpy()[0])
        pred_label = probs.index(max(probs))
        true_label = test_labels.numpy()[i]
        if pred_label == true_label:
            correct_count += 1
        total += 1


#打印准确率,保存模型
print('total acc:{}'.format(correct_count/total))
t.save(model,'./nn_mnist_model.pt')

运行结果:

1100 0.11727560311555862
1200 0.08096164464950562
1300 0.23481614887714386
1400 0.017000077292323112
1500 0.03314385190606117
1600 0.10333727300167084
1700 0.04443451762199402
1800 0.024772992357611656
total acc:0.9642

参考链接:https://mp.weixin.qq.com/s?__biz=MzA4MDExMDEyMw==&mid=2247488419&idx=1&sn=0cbeb05b9d5e33a8be8a2ce5b2aa4e97&chksm=9fa864e7a8dfedf18205fd6d1266063771ec999dd7a72403b7c531b91331e988db74c32d1af1&scene=178#rd

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值