经典数据集-手写数字识别pytorch

pytorch经典数据集-手写数字识别

一、什么是MNIST?

MNIST是计算机视觉领域中最为基础的一个数据集,也是很多人第一个神经网络模型

MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集

MNIST中所有样本都会将原本28*28的灰度图转换为长度为784的一维向量作为输入,其中每个元素分别对应了灰度图中的灰度值。MNIST使用一个长度为10的one-hot向量作为该样本所对应的标签,其中向量索引值对应了该样本以该索引为结果的预测概率。

二、详细代码介绍

MNIST手写数字识别的主要目为:训练出一个模型,让这个模型能够对手写数字图片进行分类。

首先先搞清楚步骤流程,然后才开始构建网络结构开始训练模型

导入要用到的库

utils是外部文件,自己定义的几个函数,详细代码已放文章末尾

#导入需要的各种库
import torch
#神经网络
from torch import nn
#function神经网络中常见的函数
from torch.nn import functional as F
#梯度下降优化包
from torch import optim
#图形视觉包
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

加载数据集

#1 加载数据集
#load dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)#shuffle打乱

#预览训练集数据
x, y=next(iter(train_loader))
print(x.shape,y.shape,x.min(),x.max())

#画图,图片识别,识别结果
plot_image(x,y,'image_sample')

用Net模型创建三层的网络结构+加一层relu激活函数层


#2 创建网络
#制作三层线性网络层 + relu函数 网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

    #三层线性 xw +b
    #第一层 28*28 =》打平成一个向量 输出是中间层,一般取2^n,逐步减小
        #Linear(输入,输出)
        self.fc1 = nn.Linear(28*28,256)
        #第二层 上一层输出是这一层的输入
        self.fc2 = nn.Linear(256,64)
        #第三层 是最终的输出=== 分类数有关
        self.fc3 = nn.Linear(64,10)

    def forward(self,x):
        # x[512,1,28,28] 输入层结构:512张灰度图片,28*28
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        #x = F.relu(self.fc3(x))
        #一般来说,最后一层激活函数可加可不加
        x = self.fc3(x)
        return x

网络训练

#3 网络训练
#迭代的次数,对数据集迭代3次
for epoch in range(3):
    #每次迭代,对数据集每512张做训练
    for batch_idx, (x,y) in enumerate(train_loader):
        # x[512,1,28,28] 28*28===1*784 打平矩阵,维度转换
        x = x.view(x.size(0),28*28)#一维 1*784
        # 放入网络训练
        #out:[512,10]
        out = net(x)
        #label用onthot编码转化成向量
        y_onehot = one_hot(y)
        #计算loss 欧式距离
        loss = F.mse_loss(out,y_onehot)
        #梯度下降
        #梯度清零
        optimizer.zero_grad()
        #计算梯度
        loss.backward()
        #更新梯度 w' = w - lr * grad
        optimizer.step()
        #此时退出循环,得到了最好的结果【w1,w2,w3,b1,b2,b3】
        if batch_idx % 10 == 0:
            #每10次打印loss
            print(epoch,batch_idx,loss.item())

验证测试

#4 验证
total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)
    out = net(x)#[512,x]
    pred = out.argmax(dim =1)#dim维度
    #pred =? 相等的数量有几张 eq()相等记为1,不相等记为0
    correct =pred.eq(y).sum().float().item()
    total_correct+=correct

total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:",acc)

x,y =next(iter(test_loader))
out =net(x.view(x.size(0),28*28))
pred = out.argmax(dim =1)
plot_image(x,pred,'test')

全部代码

如果要编写成一个脚本的话,把下面函数部分复制到同一个py文件就行了,就不用多创建一个py文件,为了让代码更好维护与调试建议分开它

import torch
from matplotlib import pyplot as plt#绘图库
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

mnist.py文件分享地址需要自取

链接: https://pan.baidu.com/s/1psjbAH5wxtaAyQpRXArr6g?pwd=y88a 提取码: y88a 复制这段内容后打开百度网盘手机App,操作更方便哦

如有错误之处请指正

  • 2
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值