PyTorch学习笔记——(7)使用pytorch实现手写数字识别,可以很好的练习pytorch

本文是PyTorch学习笔记,通过实现手写数字识别来讲解数据预处理、模型构建、训练和评估过程。使用了MNIST数据集,介绍了torchvision.transform的ToTensor、Normalize和Compose方法,以及模型的保存和加载。
摘要由CSDN通过智能技术生成

上一小节.
GitHub地址.

1、思路和流程分析

流程:

  • 1.准备数据,这些需要准备DataLoader
  • 2.构建模型,这里可以使用torch构造一个深层的神经网络
  • 3.模型的训练
  • 4.模型的保存,保存模型,后续持续使用
  • 5.模型的评估,使用测试集,观察模型的好坏

2、准备训练集和测试集

准备数据集的方法上一节已经讲过数据加载Dataset和DataLoader的使用,在这里我们使用pytorch自带的mnist数据集来做,也就是说,我们不需要再去写自己的dataset类了,pytorch已经封装好了,我们只需要调用即可。API如下所示:

mnist = MNIST(path, train=True, download=True) # 是个dataset的实例

参数:
path:表示保存数据的路径;
train:训练集的数据;
download:是否下载,因为第一次使用要从官网下载,下载了在之后,就可以设置为False了。
后面还有参数之后再介绍。

但是,调用MNIST返回的结果中图形数据是一个mage对象需要对其进行处理。

为了进行数据的处理,接下来需要学习torchvision.transform的方法:

2.1 torchvision.transform的图形数据处理方法

(1) torchvision.transform.ToTensor

作用:把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W] ,取值范围是[0,1. 0]的torch.FloatTensor

其中(H,W,C)意思为(高,宽,通道数),黑白图片的通道数只有1,其中每个像素点的取值为
[0,255],彩色图片的通道数为(R,G,B)每个通道的每个像素点的取值为[0,255],三个通道的颜色相互叠加,形成了各种颜色。

示例如下:

from torchvision.datasets import MNIST
from torchvision import transforms

mnist = MNIST("./data", train=True, download=False) # 是个dataset的实例

# mnist[0][0].show() # 由于mnist是一个实例化对象,则可以用[]方式取每一条数据,mnist[0]表示第一条数据,是个元组(image,label)

print(mnist[0]) # 第1条数据,是个元组(image,label)
ret = transforms.ToTensor()(mnist[0][0]) # 将mage对象转换成张量,[28,28,1]->[1,28,28]
print(ret.size())

运行结果:

(<PIL.Image.Image image mode=L size=28x28 at 0x211EE644198>, 5)
torch.Size([1, 28, 28])

注意:
transforms.ToTensor对象有个__call__方法,所以可以对其示例能够传入数据获取结果。

(2)torchvision.transforms.Normalize(mean, std)

作用:标准化张量
参数:
给定均值: mean注意: shape和图片的通道数相同(指的是每个通道的均值);
方差: std, 注意:和图片的通道数相同(指的是每个通道的方差),将会把Tensor规范化处理
即: Normalized_ image=(image-mean)/std.

例如:

from torchvision.datasets import MNIST
from torchvision import transforms

mnist = MNIST("./data", train=True, download=False) # 是个dataset的实例

# mnist[0][0].show() # 由于mnist是一个实例化对象,则可以用[]方式取每一条数据,mnist[0]表示第一条数据,是个元组(image,label)

print(mnist[0]) # 第1条数据,是个元组(image,label)
ret = transforms.ToTensor()(mnist[0][0])
print(ret.size())

norm_img = transforms.Normalize(mean=[0.1307], std=[0.3081])(ret)  # 进行规范化处理
print(norm_img)

运行结果:

在这里插入图片描述

(3)torchvision.transforms.Compose(transforms)

作用:将多个transforms组合起来使用

例如:

transforms.Compose([transforms.ToTensor()(), # 先转化为Tensor
                   transforms.Normalize(mean, std)]) # 再进行正则化

2.2 准备MNIST数据集的Dataset和DataLoader

# 1、准备数据集
  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值