Pytorch之实战手写数字识别

本文基于《深度学习之Pytorch实战计算机视觉》的6.4节,介绍了如何使用Pytorch进行手写数字识别。作者在实现书中代码时遇到了版本差异问题并进行了修正,确保代码能正常运行并得出结果。内容包括数据类型转换、数据集下载、数据装载、CNN模型搭建、损失函数和优化、模型训练及预测。
摘要由CSDN通过智能技术生成

本文代码是**唐进民的《深度学习之Pytorch实战计算机视觉》**中的6.4节“实战手写数字识别”
我把书中代码敲到pycharm上时,可能由于版本之间的差异,出现了不同的问题。然后对几处代码进行了修改,使其能正常运行,并得出结果。
写本文的目的,就是希望自己通过学习这个案例,能对pytorch和cnn有一个基础的理解。

1.导入必要的包

import torch
from torchvision import datasets,transforms
import torchvision
from torch.autograd import  Variable
import numpy as np
import matplotlib.pyplot as plt

2.数据类型转换

将图片型数据转换成Tensor数据类型

transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Lambda(lambda x: x.repeat(3,1,1)),
                              transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])])

3.数据集(训练集和测试集)的下载

data_train=datasets.MNIST(root="./data", 
                          transform = transform,
                          train=True,
                          download=True
                          )
data_test=datasets.MNIST(root="./data",
                         transform = transform,
                         train=False)

torchvision.datasets再加上需要下载的数据集名称就可以下载数据集。
root 用于指定数据集在下载之后的路径,这里选择存放在根目录下的data文件夹下。
train 用于指定在数据集下载完成之后需要载入哪部分数据,如果train=True,则表示,载入的是该数据集的训练集部分,反之,则是该数据集的测试集部分

4.数据装载

数据的载入可以认为是对图片的处理,处理完这些照片,要将他们打包好送给我们的模型进行训练,然后数据的装载就是打包图片的过程。

data_loader_train=torch.utils.data.DataLoader(dataset=data_train,
                                              batch_size=64,
                                              shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,
                                             batch_size=64,
                                             shuffle=True)

batch_size用来确认每个包的大小,这里等于64就是在每个包里有64张照片的意思。
shuffle是来确认要不要在装载过程中打乱图片的顺序,为True则表明要打乱顺序。

5.数据预览

images,labels=next(iter(data_loader_train))  #获取一个批次的图片数据和对应图片的标签
img=torchvision.utils.make_grid(images)  # 将一个批次的照片构造成网格模式

img=img.numpy().transpose(1,2,0)

std=[0.5,0.5,0.5]
mean=[0.5,0.5,0.5]

img=img*std+mean

# print([labels[i] for i in range(64)])
# plt.imshow(img)
# plt.show()


# 让tensor数据,按图片中数字显示出来
for i in range(64):
    print(labels[i], end=" ")
    i += 1
    if i%8 is 0:
        print(end='\n')
plt.imshow(img)
plt.show()

结果展示:

tensor(6) tensor(3) tensor(8) tensor(9) tensor(4) tensor(9) tensor(4) tensor(5) 
tensor(9) tensor(0) tensor(4) tensor(7) tensor(0) tensor(3) tensor(8) tensor(7) 
tensor(5) tensor(6) tensor(1) tensor(3) tensor(5) tensor(0) tensor(0) tensor(0) 
tensor(1) tensor(3) tensor(4) tensor(2) tensor(7) tensor(3) tensor(3) tensor(9) 
tensor(2) tensor(9) tensor(2) tensor(5) tensor(8) tensor(5) tensor(0) tensor(0) 
tensor(6) tensor(1) tensor(2) tensor(
  • 9
    点赞
  • 53
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值