PyTorch from Getting Started to Giving Up: Loading Data

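PyTorch has two primitives for working with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. A Dataset stores the samples and their corresponding labels. A DataLoader wraps a Dataset so that it is easy to load, iterate over, and split into batches.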

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
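To make the split between the two primitives concrete, here is a minimal sketch that does not touch any real data. The SquaresDataset class and its contents are hypothetical and exist only for illustration: a Dataset needs to answer "how many samples are there" and "give me sample i", and the DataLoader takes care of grouping samples into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the number i, its label is i squared."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples in the dataset.
        return self.n

    def __getitem__(self, idx):
        # Return one (sample, label) pair.
        x = torch.tensor([float(idx)])
        y = torch.tensor(idx * idx)
        return x, y


toy_data = SquaresDataset(10)
toy_loader = DataLoader(toy_data, batch_size=4)

# Collect the shape of every batch the loader yields.
batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in toy_loader]
print(len(toy_data), len(toy_loader))
print(batch_shapes)
```

With 10 samples and a batch size of 4, the loader yields two full batches and one final batch of 2 samples.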

 

  

PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, each of which includes datasets. In this tutorial we use TorchVision.

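The torchvision.datasets module contains a variety of vision datasets, such as CIFAR and COCO (see the full list in the torchvision documentation). In this tutorial we use the FashionMNIST dataset. Every TorchVision dataset accepts two arguments, transform and target_transform, which modify the samples and the labels respectively.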

# Download the training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
 
# Download the test data.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Output:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 65536/26421880 [00:00<01:12, 365718.31it/s]
  1%|          | 229376/26421880 [00:00<00:38, 685682.68it/s]
  3%|3         | 884736/26421880 [00:00<00:10, 2498938.52it/s]
  7%|7         | 1933312/26421880 [00:00<00:05, 4141475.37it/s]
 19%|#8        | 4915200/26421880 [00:00<00:01, 10854978.12it/s]
 26%|##5       | 6782976/26421880 [00:00<00:01, 11037400.65it/s]
 37%|###7      | 9797632/26421880 [00:01<00:01, 15568756.79it/s]
 44%|####4     | 11730944/26421880 [00:01<00:01, 14184748.16it/s]
 55%|#####5    | 14647296/26421880 [00:01<00:00, 17510568.70it/s]
 63%|######3   | 16777216/26421880 [00:01<00:00, 15834704.91it/s]
 75%|#######4  | 19693568/26421880 [00:01<00:00, 18759775.35it/s]
 83%|########2 | 21889024/26421880 [00:01<00:00, 16780435.96it/s]
 94%|#########3| 24772608/26421880 [00:01<00:00, 19391805.01it/s]
100%|##########| 26421880/26421880 [00:01<00:00, 13914460.04it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 326673.50it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|1         | 65536/4422102 [00:00<00:12, 362354.20it/s]
  5%|5         | 229376/4422102 [00:00<00:06, 684627.79it/s]
 21%|##        | 917504/4422102 [00:00<00:01, 2626211.85it/s]
 44%|####3     | 1933312/4422102 [00:00<00:00, 4103892.12it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 6109664.51it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 61868988.52it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
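The transform and target_transform arguments can be easier to understand with a small stand-in that needs no download. The sketch below is not torchvision's implementation: TinyImageDataset, scale_to_unit, and one_hot are hypothetical names, and the pixel data is random. It only illustrates the idea that one hook is applied to each sample and the other to each label when an item is fetched.

```python
import torch
from torch.utils.data import Dataset


class TinyImageDataset(Dataset):
    """Hypothetical stand-in that mimics the transform / target_transform hooks."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)  # modifies the sample
        if self.target_transform is not None:
            label = self.target_transform(label)  # modifies the label
        return image, label


def scale_to_unit(image):
    # Convert uint8 pixels in [0, 255] to float32 in [0.0, 1.0].
    return image.to(torch.float32) / 255.0


def one_hot(label, num_classes=10):
    # Turn an integer class index into a one-hot float vector.
    vec = torch.zeros(num_classes)
    vec[label] = 1.0
    return vec


# Random placeholder data shaped like five 28x28 grayscale images.
raw_images = torch.randint(0, 256, (5, 1, 28, 28), dtype=torch.uint8)
raw_labels = [9, 0, 0, 3, 0]

tiny_data = TinyImageDataset(
    raw_images, raw_labels, transform=scale_to_unit, target_transform=one_hot
)
image, label = tiny_data[0]
print(image.dtype, tuple(image.shape), label)
```

In the FashionMNIST code above, transform=ToTensor() plays the role of the sample hook, and target_transform is left at its default, so the labels stay as integer class indices.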

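We pass the Dataset as an argument to DataLoader. This wraps the dataset in an iterable that supports automatic batching, sampling, shuffling, and multiprocess data loading. Here we define a batch size of 64, so each iteration over the dataloader returns a batch of 64 features and labels.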

batch_size = 64
 
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
 
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
Output:

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
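The loop above only looks at the first batch. As a rough sketch of what happens over a full pass, the example below uses a synthetic TensorDataset as a stand-in for the test split: it assumes 10,000 samples of shape 1x28x28 (the size of the FashionMNIST test set), and the zero-filled images and cyclic labels are placeholders, not real data. It shows how many batches come out, how large the last one is, and what the drop_last and shuffle options change.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in with the same size and shape as the test split
# (10,000 grayscale 28x28 images); the values are placeholders.
fake_images = torch.zeros(10_000, 1, 28, 28)
fake_labels = torch.arange(10_000) % 10
fake_test_data = TensorDataset(fake_images, fake_labels)

batch_size = 64
loader = DataLoader(fake_test_data, batch_size=batch_size)

# 10,000 is not a multiple of 64, so the final batch is smaller.
num_batches = len(loader)
for last_X, last_y in loader:
    pass  # after the loop, last_X and last_y hold the final batch

print(num_batches, tuple(last_X.shape))

# drop_last=True discards the incomplete final batch.
full_only = DataLoader(fake_test_data, batch_size=batch_size, drop_last=True)
print(len(full_only))

# shuffle=True reorders the samples on every pass; it is commonly used
# for the training set rather than the test set.
shuffled = DataLoader(fake_test_data, batch_size=batch_size, shuffle=True)
_, first_shuffled_y = next(iter(shuffled))
print(tuple(first_shuffled_y.shape))
```

Since 156 full batches cover 9,984 samples, the default loader yields 157 batches and the last one holds the remaining 16 samples.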

To learn more, see Loading Data in PyTorch.
