本文代码是**唐进民的《深度学习之Pytorch实战计算机视觉》**中的6.4节“实战手写数字识别”
我把书中代码敲到pycharm上时,可能由于版本之间的差异,出现了不同的问题。然后对几处代码进行了修改,使其能正常运行,并得出结果。
写本文的目的,就是希望自己通过学习这个案例,能对pytorch和cnn有一个基础的理解。
代码框架
1.导入必要的包
import torch
from torchvision import datasets,transforms
import torchvision
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
2.数据类型转换
将图片型数据转换成Tensor数据类型
transform=transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3,1,1)),
transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])])
3.数据集(训练集和测试集)的下载
data_train=datasets.MNIST(root="./data",
transform = transform,
train=True,
download=True
)
data_test=datasets.MNIST(root="./data",
transform = transform,
train=False)
torchvision.datasets再加上需要下载的数据集名称就可以下载数据集。
root 用于指定数据集在下载之后的路径,这里选择存放在根目录下的data文件夹下。
train 用于指定在数据集下载完成之后需要载入哪部分数据,如果train=True,则表示,载入的是该数据集的训练集部分,反之,则是该数据集的测试集部分
4.数据装载
数据的载入可以认为是对图片的处理,处理完这些照片,要将他们打包好送给我们的模型进行训练,然后数据的装载就是打包图片的过程。
data_loader_train=torch.utils.data.DataLoader(dataset=data_train,
batch_size=64,
shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,
batch_size=64,
shuffle=True)
batch_size用来确认每个包的大小,这里等于64就是在每个包里有64张照片的意思。
shuffle是来确认要不要在装载过程中打乱图片的顺序,为True则表明要打乱顺序。
5.数据预览
images,labels=next(iter(data_loader_train)) #获取一个批次的图片数据和对应图片的标签
img=torchvision.utils.make_grid(images) # 将一个批次的照片构造成网格模式
img=img.numpy().transpose(1,2,0)
std=[0.5,0.5,0.5]
mean=[0.5,0.5,0.5]
img=img*std+mean
# print([labels[i] for i in range(64)])
# plt.imshow(img)
# plt.show()
# 让tensor数据,按图片中数字显示出来
for i in range(64):
print(labels[i], end=" ")
i += 1
if i%8 is 0:
print(end='\n')
plt.imshow(img)
plt.show()
结果展示:
tensor(6) tensor(3) tensor(8) tensor(9) tensor(4) tensor(9) tensor(4) tensor(5)
tensor(9) tensor(0) tensor(4) tensor(7) tensor(0) tensor(3) tensor(8) tensor(7)
tensor(5) tensor(6) tensor(1) tensor(3) tensor(5) tensor(0) tensor(0) tensor(0)
tensor(1) tensor(3) tensor(4) tensor(2) tensor(7) tensor(3) tensor(3) tensor(9)
tensor(2) tensor(9) tensor(2) tensor(5) tensor(8) tensor(5) tensor(0) tensor(0)
tensor(6) tensor(1) tensor(2) tensor(