李沐基于Pytorch的深度学习笔记(11.5)-MNIST的数字识别实现(含代码)

本次博文参考的文章如下:

用PyTorch实现MNIST手写数字识别(非常详细)_小锋学长生活大爆炸-CSDN博客_pytorch手写数字识别

我们对这篇文章进行了一个更加详细的讲解

首先我们来设置相关的包和库以及后面会用到的数据

import torch
import numpy as np
import pandas as pd
import random
import matplotlib
import matplotlib.pyplot as plt
import os
import torchvision
from torchvision import transforms
from sklearn.datasets import load_boston
import joblib
from sklearn.linear_model import LinearRegression, SGDRegressor, Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch.utils.data import DataLoader
from torch import nn
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.manual_seed(random_seed)#因为每次设置的随机数种子,其实打印出来的数值是相同的,那么使用manual_seed来固定每次的种子固定值,这样在排序随机数时候,当内部选择不一样时,就可以数值不同了

t.manual_seed(1)_fnzwj的博客-CSDN博客_manual_seed

关于manual_seed的一个实际用法,可以参考上卖弄那篇博文。

对于可重复的实验,我们必须为任何使用随机数产生的东西设置随机种子——如numpy和random! 

0现在我们还需要数据集的dataloader。这就是TorchVision发挥作用的地方。它让我们用一种方便的方式来加载MNIST数据集。我们将使用batch_size=64进行训练,并使用size=1000对这个数据集进行测试。下面的Normalize()转换使用的值0.1307和0.3081是MNIST数据集的全局平均值和标准偏差,这里我们将它们作为给定值。

TorchVision提供了许多方便的转换,比如裁剪或标准化。

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

首先需要看一下torch.utils.data.DataLoader什么意思:

PyTorch 中的数据类型 torch.utils.data.DataLoader_rogerfang的博客-CSDN博客_dataloader数据类型

关于torchvision.datasets.MNIST数据集的具体用法,参考下面的

​​​​​​[ PyTorch ] torch.utils.data.DataLoader 中文使用手册_江南蜡笔小新-CSDN博客_pytorch torch.utils.data

transforms.Compose关于这个函数可以参考下下面的解释

torchvision中transforms.Compose学习理解_thinklis的博客-CSDN博客_torchvision.transforms.compose

理解transforms.ToTensor()函数_LJC的博客-CSDN博客_transforms.totensor

关于transforms.Normalize()函数_开飞机的小毛驴儿-CSDN博客_transforms.normalize

进行单独的运行,对数据集进行下载

 除了数据集和批处理大小之外,PyTorch的DataLoader还包含一些有趣的选项。例如,我们可以使用num_workers > 1来使用子进程异步加载数据,或者使用固定RAM(通过pin_memory)来加速RAM到GPU的传输。但是因为这些在我们使用GPU时很重要,我们可以在这里省略它们。

    现在让我们看一些例子。我们将为此使用test_loader。

    让我们看看一批测试数据由什么组成。

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)

 这意味着我们有1000个例子的28x28像素的灰度(即没有rgb通道)。

    我们可以使用matplotlib来绘制其中的一些

import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
plt.show()

后续构建网络,李沐老师还没教到,就不做过多讲解了

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ReedswayYuH.C

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值