1、思路和流程分析
流程:
- 1.准备数据,这些需要准备DataLoader
- 2.构建模型,这里可以使用torch构造一个深层的神经网络
- 3.模型的训练
- 4.模型的保存,保存模型,后续持续使用
- 5.模型的评估,使用测试集,观察模型的好坏
2、准备训练集和测试集
准备数据集的方法上一节已经讲过数据加载Dataset和DataLoader的使用,在这里我们使用pytorch自带的mnist数据集来做,也就是说,我们不需要再去写自己的dataset类了,pytorch已经封装好了,我们只需要调用即可。API如下所示:
mnist = MNIST(path, train=True, download=True) # 是个dataset的实例
参数:
path
:表示保存数据的路径;
train
:训练集的数据;
download
:是否下载,因为第一次使用要从官网下载,下载了在之后,就可以设置为False了。
后面还有参数之后再介绍。
但是,调用MNIST返回的结果中图形数据是一个mage对象需要对其进行处理。
为了进行数据的处理,接下来需要学习torchvision.transform
的方法:
2.1 torchvision.transform的图形数据处理方法
(1) torchvision.transform.ToTensor
作用:把一个取值范围是[0,255]的PIL.Image
或者shape
为(H,W,C)的numpy.ndarray
,转换成形状为[C,H,W] ,取值范围是[0,1. 0]的torch.FloatTensor
。
其中(H,W,C)意思为(高,宽,通道数),黑白图片的通道数只有1,其中每个像素点的取值为
[0,255],彩色图片的通道数为(R,G,B)每个通道的每个像素点的取值为[0,255],三个通道的颜色相互叠加,形成了各种颜色。
示例如下:
from torchvision.datasets import MNIST
from torchvision import transforms
mnist = MNIST("./data", train=True, download=False) # 是个dataset的实例
# mnist[0][0].show() # 由于mnist是一个实例化对象,则可以用[]方式取每一条数据,mnist[0]表示第一条数据,是个元组(image,label)
print(mnist[0]) # 第1条数据,是个元组(image,label)
ret = transforms.ToTensor()(mnist[0][0]) # 将mage对象转换成张量,[28,28,1]->[1,28,28]
print(ret.size())
运行结果:
(<PIL.Image.Image image mode=L size=28x28 at 0x211EE644198>, 5)
torch.Size([1, 28, 28])
注意:
transforms.ToTensor
对象有个__call__
方法,所以可以对其示例能够传入数据获取结果。
(2)torchvision.transforms.Normalize(mean, std)
作用:标准化张量。
参数:
给定均值: mean
,注意: shape和图片的通道数相同(指的是每个通道的均值);
方差: std
, 注意:和图片的通道数相同(指的是每个通道的方差),将会把Tensor规范化处理。
即: Normalized_ image=(image-mean)/std
.
例如:
from torchvision.datasets import MNIST
from torchvision import transforms
mnist = MNIST("./data", train=True, download=False) # 是个dataset的实例
# mnist[0][0].show() # 由于mnist是一个实例化对象,则可以用[]方式取每一条数据,mnist[0]表示第一条数据,是个元组(image,label)
print(mnist[0]) # 第1条数据,是个元组(image,label)
ret = transforms.ToTensor()(mnist[0][0])
print(ret.size())
norm_img = transforms.Normalize(mean=[0.1307], std=[0.3081])(ret) # 进行规范化处理
print(norm_img)
运行结果:
(3)torchvision.transforms.Compose(transforms)
作用:将多个transforms组合起来使用。
例如:
transforms.Compose([transforms.ToTensor()(), # 先转化为Tensor
transforms.Normalize(mean, std)]) # 再进行正则化
2.2 准备MNIST数据集的Dataset和DataLoader
# 1、准备数据集