PyTorch实战手写数字识别

本文详细介绍了使用PyTorch实现MNIST手写数字识别的过程,包括数据下载、数据变换、模型搭建、损失函数和优化器选择、模型训练及验证。在数据预处理阶段,解决了因图片通道数不匹配导致的RuntimeError,最终成功训练并验证了模型的准确性。
摘要由CSDN通过智能技术生成

目录

1. 导包

import torch
from torchvision import datasets, transforms, utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

2.数据下载

数据下载过于缓慢,可在网上搜索MNIST数据包(共有四个文件),将其存放到相应的文件目录下(即文件结构图中的raw文件夹),再运行程序,进行数据处理。

  • 数据
    手写数字数据集,共有60000张训练图片和10000张测试图片。
    官方网站:http://yann.lecun.com/exdb/mnist/
    百度云网盘
    链接:https://pan.baidu.com/s/1fp3E279lOOwcx_Zq62-iaQ
    提取码:p3tb
  • MNIST数据包中包含的文件:
    在这里插入图片描述
  • 最终得到的文件结构:
    在这里插入图片描述
  • 代码
data_train = datasets.MNIST(root="./data/",
                           transform=transform,
                           train=True,
                            download=True)

data_test = datasets.MNIST(root="./data/",
                          transform=transform,
                          train=False)

3. 数据变换

程序下载代码中的transform=transform这一句,等式右端的transform是在《深度学习之PyTorch实战计算机视觉》6.4.2部分定义的,所以在编程时,要将transform的定义放在数据下载代码的前面。

  • 代码
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

4. 数据装载

  • 代码
data_loader_train = torch.utils.data.DataLoader(dataset=data_train, 
                                               batch_size=64,
                                               shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                              batch_size=64,
                                              shuffle=True)

4. 数据预览

  • 代码
images, labels = next(iter(dat
  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值