Pytorch 中 Dataset 和 DataLoader,以及 torchvision 的 datasets 完全理解

1、torch.utils.data.Dataset()

首先最基础的,是 torch.utils.data.Dataset()官方文档),它是 Pytorch 中表示数据集的抽象类,可以将其理解为如下:

class Dataset(object):

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
  • __getitem__() 方法通过索引返回数据集中选定的样本
  • __len__() 方法返回数据集的总大小(实际上没定义在抽象类中,而是在 Sampler 中)

可见抽象类中的两个方法都是还没实现的,所以如果想实际使用 Dataset,就必须继承这个抽象类,创建一个子类,改写这两个方法,例如:

class CustomDataset(torch.utils.data.Dataset):

	# Basic Instantiation
	def __init__(self, ..., *args, **kwargs):
		...
	# Fetch an item from the Dataset
	def __getitem__(self, index):
		...	
	# Length of the Dataset
	def __len__(self):
		...

自定义数据集的具体例子可以看看这篇博客

2、torch.utils.data.Sampler()

有了数据集之后,就需要从中采样数据,这就是 torch.utils.data.Sampler()官方文档) 的作用,它是所有采样器的基类,可以将其理解如下:

class Sampler(object)

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError
  • __iter__() 方法用于迭代数据集元素索引

从官方实现的各种 Sampler 的子类源代码中可以看出,__iter__() 方法实际上就是用 Python 中的 iter()next()yield 等迭代器和生成器的方法(详见这篇博客),基于数据集产生一个迭代器,可以迭代得到数据集上的样本。

3、torch.utils.data.DataLoader()

最后就是 torch.utils.data.DataLoader()官方文档),它的作用就是:

Combines a dataset and a sampler, and provides an iterable over the given dataset

结合一个 Dataset 和一个 Sampler,然后返回一个该数据集上的可迭代对象。当然它还可以指定 Batch_size,以及支持多进程等等。

4、torchvision.datasets.ImageFolder()

首先介绍下 torchvision 包,它和 torch 一样都归属于 Pytorch 深度学习框架,torchvision 是由常用数据集、模型架构和用于计算机视觉的常见图像转换所组成的。

torchvision.datasets 模块(官方文档)既有官方提供的数据集,也有自定义数据集的类,它们都是 torch.utils.data.Dataset子类,因此可以直接输入到 torch.utils.data.DataLoader 中。

官方提供的数据集如:torchvision.datasets.MNIST()torchvision.datasets.FashionMNIST()torchvision.datasets.ImageNet() 等等;

自定义的数据集类有三个,最常用的是 torchvision.datasets.ImageFolder(),它继承自 torchvision.datasets.DatasetFolder(),后者又继承自 torchvision.datasets.VisionDataset(),VisionDataset 则是 torch.utils.data.Dataset 的子类。传入 ImageFolder 的 root 路径参数,里面的子文件夹对应类别名,然后类别名文件夹里面就存放有该类别的图片,如下:

 root
├── orange
│   ├── orange_image1.png
│   └── orange_image1.png
├── apple
│   └── apple_image1.png
│   └── apple_image2.png
│   └── apple_image3.png

5、例子 torchvision.datasets.FashionMNIST()

首先是导入库:

# imports 导入各种库
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

然后就可以直接使用官方提供的数据集了,其中有许多参数,如存放的路径、作为训练集还是测试集、如何数据增强等等:

train_set = torchvision.datasets.FashionMNIST( 
    root='./data' 
    ,train=True 
    ,download=True 
    ,transform=transforms.Compose([ 
        transforms.ToTensor() 
    ]) 
)

得到数据集之后,就可以输入到 DataLoader 中,它也有很多参数,如批量大小、是否打乱、使用 CPU 的进程数等等:

train_loader = torch.utils.data.DataLoader(
	train_set 
    ,batch_size=4
    ,shuffle=True 
)

此时,train_loader 就是一个可迭代对象,我们既可以先用 iter() 将其变成迭代器再用 next() 逐个批次进行迭代,例如:

在这里插入图片描述

next() 返回的一个批次大小为 4,包含图像数据和标签数据,对图像进行可视化如下:

在这里插入图片描述

也可以直接 for 循环进行迭代,如:

for batch_idx, samples in enumerate(train_loader):
      print(batch_idx, samples)

更具体的例子可以看看这篇博客

### 回答1: 如果你是在问 PyTorch 的数据集和数据加载器,那么我很愿意为您解答。 PyTorch 是一个开源深度学习框架,其包含了用于加载和预处理数据的工具。其最重要的两个组件是数据集 (Dataset) 和数据加载器 (DataLoader)。 数据集是一个 PyTorch 类,它定义了如何读取数据、如何访问数据以及如何将数据转换为张量。您可以使用内置的数据集类,例如 torchvision.datasets 的 ImageFolder,或者自定义数据集类。 数据加载器是一个 PyTorch 类,它可以从数据集读取数据并将其批量加载到内存。数据加载器可以进行并行加载,以提高加载速度,并且可以通过 shuffle 参数来随机打乱数据。 举个例子,如果您有一个图像数据集,您可以使用以下代码来创建数据集和数据加载器: ``` import torch import torchvision # 创建数据集 dataset = torchvision.datasets.ImageFolder(root='path/to/data', transform=transforms.ToTensor()) # 创建数据加载器 dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) ``` 然后,您可以使用以下代码来读取数据: ``` for inputs, labels in dataloader: # 处理输入数据 ... ``` 希望对您有所帮助! ### 回答2: PyTorch是一种广泛使用的深度学习框架,具有易于使用的API和优秀的性能。其DatasetDataLoader是两个非常重要的类,它们可以帮助我们有效地加载和处理数据。 Dataset是一个抽象的概念,用于表示一组数据。我们可以继承它并重写其的方法,以实现对不同数据集的适配。在初始化时,我们需要传递一个数据集,比如说图片数据集,然后在DataLoader使用这个数据集,实现数据的准备和加载。在自定义Dataset时,我们需要定义__getitem__和__len__两个方法,分别用于返回数据集的某个数据和数据总数。 DataLoader是一个非常实用的工具,用于加载数据并把数据变成可迭代的对象,其包含了批量大小、数据是否随机等设置。我们可以设置num_workers参数,用多个进程来读取数据提高读取数据的速度。通过使用DataLoader,我们可以很方便地迭代整个数据集,可以按批次加载和处理数据。 当我们使用在线学习时,经常需要不断地读取数据并进行训练。在应用,我们会遇到许多不同的数据集,其可能包含不同的数据类型,比如图像、音频、文本等。使用DatasetDataLoader类,我们可以轻松处理这些数据,从而使我们的深度学习应用具有更广泛的适用性和扩展性。 总之,DatasetDataLoaderPyTorch非常重要的类,它们可以帮助我们非常方便地进行数据的处理和加载。无论你想要使用哪种数据集,它们都能够很好地适配。在实际应用,我们可以灵活地使用这两个类来加载和准备数据并进行训练,从而加快应用的速度并提高深度学习的精度。 ### 回答3: PyTorch是一个流行的深度学习框架,它提供了DatasetDataLoader这两个类来帮助我们更方便地处理数据。 Dataset可以看作是一个数据集,它定义了如何读取数据。官方提供了两种Dataset:TensorDataset和ImageFolder。TensorDataset是用来处理张量数据,而ImageFolder则是用来处理图像数据。如果我们需要使用其他类型的数据,我们可以通过重写Dataset的__getitem__和__len__方法来实现。 在实现Dataset之后,我们需要将数据读取到内存,在模型训练时提供给模型,这时我们就需要使用到DataLoader了。DataLoader可以看作是一个数据加载器,它会自动将Dataset的数据批量读取到内存,并且支持数据的分布式加载。 在使用DataLoader时我们可以设置很多参数,比如batch_size表示每个batch的大小,shuffle表示是否打乱数据顺序,num_workers表示使用多少线程读取数据等等。这些参数都可以帮助我们更好地利用硬件资源,提高训练速度和效率。 使用PyTorchDatasetDataLoader可以帮助我们更方便快捷地处理数据,并且让我们可以更专注于模型的设计和训练。但我们也要注意一些细节问题,比如数据读取是否正确、内存使用是否合理等等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值