pytorch识别手写数字(MNIST)

#此文章用于记录学习进度#

一.选择数据集

对于识别手写数字这个目的来说,我们的数据集肯定是很多张手写数字的图片,那么这里用到的是MNIST数据集,其中的图片格式为28*28的灰度图。

二.理解神经网络的构建过程

我们先来认识一下图片的存储形式,数据集中的28*28的灰度图是以28*28的矩阵存储的,而且因为是灰度图,这里是一个二维数组的形式存储这个矩阵,每个元素大小范围是[0,255]。

1.第0层 

构建前我们把这样一个矩阵reshape成一个一维数组,并且用0和1代表这个位置是否有笔迹(这里我们需要设置一个阈值,大于等于该阈值则认为有笔迹,小于则认为无笔迹)。由此,我们就构建了神经网络的第0层节点。

2.第1层

第1层节点由第0层节点计算得到,这里需要设置一个节点传播公式,类似于一个函数,这里我们假设该模型是一个线性关系的模型(然而大部分具体问题都是非线性的),我们选取y=a∑xn+b,我们把第0层的n个节点映射到第1层的一个节点。由此我们得到了28*28/n个节点组成的第1层节点。

3.正向传播

类似的,我们将第1层的节点向前传播,得到最后的输出层

最后的十个节点对应的是该图像为0~9中某个数字的概率。但在为给这十个节点赋上我们人类认知的概念前,这十个节点是没有任何意义的。那么如何让这十个节点转变为我们认知中的0~9的概率呢?那么就需要训练。

4.训练 

前三点举的例子都是对于一张图像来说,但我们MINST训练集有六万张图片,而且每张图片都带有各自的标签(即这张图片对应哪个数字),通过一次次的训练,让模型优化出最优的参数a和b(第二点提到的一次线性函数),这里的a和b也称为网络参数,由此,我们把这个模型的求解问题转化成了一个最优化问题。通过计算机对图像的一次次预测,得出预测的准确率,如果太低,则使用一些来调整参数(如梯度下降法,感兴趣可以了解一下,在数学建模中也会使用到)。

ps:这里使用的优化思想在机器学习中是非常常见的,如knn等等,而对于这些机器学习的算法参数的优化,我们也称为网格搜索。如sklearn.model_selection模块中的GridSearchCV类,这个类就是专门用于帮助我们系统地搜索最佳参数组合,从而优化模型性能。

三.运行代码

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt


# 定义一个Net类,继承torch.nn.Module
class Net(torch.nn.Module):
    # 初始化函数
    def __init__(self):
        super().__init__()
        # 定义第一个全连接层,输入维度为28*28,输出维度为64
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        # 定义第二个全连接层,输入维度为64,输出维度为64
        self.fc2 = torch.nn.Linear(64, 64)
        # 定义第三个全连接层,输入维度为64,输出维度为64
        self.fc3 = torch.nn.Linear(64, 64)
        # 定义第四个全连接层,输入维度为64,输出维度为10
        self.fc4 = torch.nn.Linear(64, 10)

   def forward(self, x):
        # 计算第一个全连接层的输出
        x = torch.nn.functional.relu(self.fc1(x))
        # 计算第二个全连接层的输出
        x = torch.nn.functional.relu(self.fc2(x))
        # 计算第三个全连接层的输出
        x = torch.nn.functional.relu(self.fc3(x))
        # 计算第四个全连接层的输出,并计算softmax激活函数
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        # 返回输出
        return x


# 定义一个函数,根据参数is_train,返回一个DataLoader对象
def get_data_loader(is_train):
    # 定义一个转换,将图像转换为Tensor
    to_sensor = transforms.Compose([transforms.ToTensor()])
    # 创建一个MNIST数据集,训练集或测试集,并使用转换to_sensor
    dataset = MNIST(root="", train=is_train, transform=to_sensor, download=True)
    # 返回一个DataLoader对象,batch_size为15,shuffle为True
    return DataLoader(dataset, batch_size=15, shuffle=True)


# 定义一个函数,用于评估测试数据和网络
def evaluate(test_data, net):
    # 定义一个变量,用于记录预测正确的样本数
    n_correct = 0
    # 定义一个变量,用于记录总样本数
    n_total = 0
    # 使用torch.no_grad(),禁止梯度计算
    with torch.no_grad():
        # 遍历测试数据
        for (x, y) in test_data:
            # 使用网络对输入数据进行预测
            outputs = net.forward(x.view(-1, 28 * 28))
            # 遍历预测结果
            for i, outputs in enumerate(outputs):
                # 如果预测结果和标签一致,则预测正确
                if torch.argmax(outputs) == y[i]:
                    n_correct += 1
                # 总样本数加1
                n_total += 1
    # 返回预测正确的样本数占总样本数的比例
    return n_correct / n_total


def main():
    # 获取训练数据
    train_data = get_data_loader(True)
    # 获取测试数据
    test_data = get_data_loader(False)
    # 实例化一个神经网络
    net = Net()

    # 打印初始准确率
    print("initial accuracy:", evaluate(test_data, net))
    # 设置优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    # 训练2个epoch
    for epoch in range(2):
        # 遍历训练数据
        for (x, y) in train_data:
            # 梯度归零
            net.zero_grad()
            # 计算输出
            outputs = net.forward(x.view(-1, 28 * 28))
            # 计算损失
            loss = torch.nn.functional.nll_loss(outputs, y)
            # 反向传播
            loss.backward()
            # 更新参数
            optimizer.step()
        # 打印准确率
        print("epoch:", epoch, "accuracy:", evaluate(test_data, net))
    # 遍历测试数据
    for (n, (x, _)) in enumerate(test_data):
        # 只取前3个
        if n > 3:
            break
        # 计算预测结果
        predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
        # 绘制图片
        plt.figure(n)
        plt.imshow(x[0].view(28, 28), cmap="gray")
        plt.title("prediction: " + str(int(predict)))
    # 显示图片
    plt.show()


if __name__ == "__main__":
    main()

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值