pytorch 预测手写体数字_PyTorch学习笔记7 - 使用PyTorch完成手写体数字识别

这篇PyTorch学习笔记介绍了如何利用LeNet5神经网络模型解决手写体数字识别问题。通过导入MNIST数据集,搭建包含卷积和池化层的LeNet5网络,训练并可视化损失曲线,最终获得超过98%的测试准确率。
摘要由CSDN通过智能技术生成

本篇笔记是PyTorch学习的最后一篇,使用PyTorch搭建神经网络,完成手写体数字识别问题。

1. 导入数据并可视化

import torch

import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

from torchvision import datasets, transforms

#定义batch size

batch_size = 64

#下载MNIST数据集

train_dataset = datasets.MNIST(root='./data/',

train=True,

transform=transforms.ToTensor(),

download=True)

test_dataset = datasets.MNIST(root='./data/',

train=False,

transform=transforms.ToTensor())

#将下载的MNIST数据导入到dataloader中

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,

batch_size=batch_size,

shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,

batch_size=batch_size,

s

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
为了进行PyTorch手写数字识别预测类别,我们可以使用线性回归模型。在这个模型中,我们首先需要加载手写数字识别数据集,并将数据集分为训练集和测试集。接下来,我们可以定义一个网络结构,该网络结构包含一个线性层和一个softmax层。然后,我们使用训练集对模型进行训练,并使用测试集对模型进行评估。 在评估过程中,我们通过模型运行测试集中的每个图像,并将模型输出的数字作为预测结果。然后,我们计算预测结果正确的数量,并将其除以测试集的总数量,得到预测的准确率。 下面是一个示例代码,展示了如何使用PyTorch进行手写数字识别预测类别: ```python import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader from tqdm import tqdm # 加载手写数字识别数据集 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform) test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False) # 定义网络结构 W = torch.randn(784, 10) # 权重矩阵 b = torch.randn(10) # 偏置向量 #评估模型 correct = 0 total = len(mnist_test) with torch.no_grad(): # 遍历测试集的小批量数据 for images, labels in tqdm(test_loader): # 前向传播 x = images.view(-1, 28*28) y = torch.matmul(x, W) + b predictions = torch.argmax(y, dim=1) # 统计预测结果正确的数量 correct += torch.sum((predictions == labels).float()) # 计算准确率 accuracy = correct / total print('Test accuracy: {}'.format(accuracy)) ``` 在上述代码中,我们首先导入所需的库,并定义了一个数据转换流程,用于将数据转换为张量并进行归一化处理。然后,我们加载手写数字识别数据集,并将其分批次加载到数据加载器中。接下来,我们定义了网络模型的参数W和b。在评估过程中,我们使用torch.no_grad()来关闭梯度计算,加快评估速度。最后,我们计算预测准确率并输出结果。 请注意,上述代码只是一个示例,实际情况中可能需要根据具体情况进行调整和修改。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [pytorch-简单回归问题-手写数字识别](https://blog.csdn.net/qq_44653420/article/details/130984978)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [基于 PyTorch手写数字分类](https://blog.csdn.net/weixin_38739735/article/details/117971150)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值