Pytorch100例-第1天:实现mnist手写数字识别第1天

本文为 🔗365天深度学习训练营 内部限免文章
参考本文所写记录性文章,请在文章开头注明以下内容,复制粘贴即可

⏲往期文章:

  • 难度:新手入门⭐
  • 语言:Python3、Pytorch

🍺 要求:

  1. 了解Pytorch,并使用Pytorch构建一个深度学习程序
  2. 了解什么是深度学习

🍻拔高(可选)

  1. 学习文中提到的函数方法

一、 前期准备

1. 设置GPU

如果设备上支持GPU就使用GPU,否则使用CPU

!pip install torchvision==0.1.8 
Collecting torchvision==0.1.8
  Downloading torchvision-0.1.8-py2.py3-none-any.whl (37 kB)
Requirement already satisfied: numpy in c:\users\jie liang\anaconda3\lib\site-packages (from torchvision==0.1.8) (1.21.5)
Requirement already satisfied: six in c:\users\jie liang\anaconda3\lib\site-packages (from torchvision==0.1.8) (1.16.0)
Requirement already satisfied: torch in c:\users\jie liang\anaconda3\lib\site-packages (from torchvision==0.1.8) (1.13.0)
Requirement already satisfied: pillow in c:\users\jie liang\anaconda3\lib\site-packages (from torchvision==0.1.8) (9.2.0)
Requirement already satisfied: typing-extensions in c:\users\jie liang\anaconda3\lib\site-packages (from torch->torchvision==0.1.8) (4.3.0)
Installing collected packages: torchvision
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.14.0
    Uninstalling torchvision-0.14.0:
      Successfully uninstalled torchvision-0.14.0
Successfully installed torchvision-0.1.8
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device
device(type='cpu')

2. 导入数据

使用dataset下载MNIST数据集,并划分好训练集与测试集

使用dataloader加载数据,并设置好基本的batch_size

torchvision.datasets.MNIST详解

torchvision.datasets是Pytorch自带的一个数据库,我们可以通过代码在线下载数据,这里使用的是torchvision.datasets中的MNIST数据集。

函数原型:

torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

参数说明:

  • root (string) :数据地址
  • train (string) :True = 训练集,False = 测试集
  • download (bool,optional) : 如果为True,从互联网上下载数据集,并把数据集放在root目录下。
  • transform (callable, optional ):这里的参数选择一个你想要的数据转化函数,直接完成数据转化
  • target_transform (callable,optional) :接受目标并对其进行转换的函数/转换。
train_ds = torchvision.datasets.MNIST('data', 
                                      train=True, 
                                      transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensor
                                      download=True)

test_ds  = torchvision.datasets.MNIST('data', 
                                      train=False, 
                                      transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensor
                                      download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz



  0%|          | 0/9912422 [00:00<?, ?it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/28881 [00:00<?, ?it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/1648877 [00:00<?, ?it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz



  0%|          | 0/4542 [00:00<?, ?it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

📌请在这里补充关于torch.utils.data.DataLoader的介绍(建议参照torchvision.datasets.MNIST详解)

⭐ torch.utils.data.DataLoader详解

torch.utils.data.DataLoader是Pytorch自带的一个数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集

函数原型:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device=‘’)

参数说明:
dataset(string) :加载的数据集

batch_size (int,optional) :每批加载的样本大小(默认值:1)

shuffle(bool,optional) : 如果为True,每个epoch重新排列数据。

sampler (Sampler or iterable, optional) : 定义从数据集中抽取样本的策略。 可以是任何实现了 len 的 Iterable。 如果指定,则不得指定 shuffle 。

batch_sampler (Sampler or iterable, optional) : 类似于sampler,但一次返回一批索引。与 batch_size、shuf

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值