Pytorch(三):Dataset和Dataloader的理解

目录

1.可迭代对象,迭代器

2.数据集遍历的一般化流程

3.Dataset

4.TensorDataset

5.Dataloader


1.可迭代对象,迭代器

首先,我们要明白python中的两个概念:可迭代对象,迭代器。

  • 可迭代对象:
  1. 实现了__iter__方法,该方法返回一个迭代器对象。
  • 迭代器:

  1. 一个带状态的对象,内部持有一个状态,该状态用于记录当前迭代所在的位置,以方便下次迭代的时候获取正确的元素。

  2. 迭代器含有__iter__和__next__方法。当调用__iter__返回迭代器自身,当调用next()方法的时候,返回容器中的下一个值。

  3. 迭代器就像一个懒加载的工厂,等到有人需要的时候才给它生成值返回,没调用的时候就处于休眠状态等待下一次调用。

  • iter()函数:

1. 用法一:iter(callable, sentinel)

不停的调用callable,直至其的返回值等于sentinel。其中的callable可以是函数,方法或实现了__call__方法的实例。

2. 用法二:iter(collection)

1)iter()直接调用可迭代对象的__iter__(),并把__iter__()的返回结果作为自己的返回值,故该用法常被称为“创建迭代器”。

2)iter函数可以显示调用,或当执行“for i in obj:”,Python解释器会在第一次迭代时自动调用iter(obj),之后的迭代会调用迭代器的next方法,for语句会自动处理最后抛出的StopIteration异常。

3)但iter函数获取不到 __iter__方法时,还会调用 __getitem__方法,参数是从0开始能获取值就是可迭代的。

2.数据集遍历的一般化流程

for i, data in enumerate(dataLoader):

enumerate(dataloader )会调用dataloader 的__iter__()方法, 产生了一个DataLoaderIter(迭代器),这里判断使用单进程还是多进程调用DataLoaderIter __next__()方法来得到batch data。 在__next__()方法方法中使用_next_index()方法调用sampler(采样器)获得索引,接着通过dataset_fetcher的fetch()方法根据index(索引)调用dataset的__getitem__()方法, 然后用collate_fn来把它们打包成batch。当数据读完后, __next__()抛出一个StopIteration异常, for循环结束, dataloader 失效.

3.Dataset

torch.utils.data.Dataset是代表这一数据的抽象类(也就是基类)。我们可以通过继承重写这个抽象类实现自己的数据类,只需要定义__len____getitem__这个两个函数

如果在类中定义了__getitem__()方法,那么实例对象(假设为P)就可以这样P[key]取值。当实例对象做P[key]操作时,就会调用类中的__getitem__()方法。

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
  • 通过复写 __getitem__ 方法可以通过索引index来访问数据,能够同时返回数据对应的标签(label),这里的数据和标签都为tensor类型 
  • 通过复写 __len__ 方法来获取数据的个数。 

 比如:

class MyDataset(Dataset): 
    """ my dataset."""
    
    # Initialize your data, download, etc.
    def __init__(self):
        # 读取csv文件中的数据
        xy = np.loadtxt('data-diabetes.csv', delimiter=',', dtype=np.float32) 
        self.len = xy.shape[0]
        # 除去最后一列为数据位,存在x_data中
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        # 最后一列为标签为,存在y_data中
        self.y_data = torch.from_numpy(xy[:, [-1]])
        
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data[index], self.y_data[index]
        
    def __len__(self): 
        # 返回文件数据的数目
        return self.len

4.TensorDataset

TensorDataset是Dataset的子类,已经复写了__len__和__getitem__方法,只需传入张量即可。

class TensorDataset(Dataset):
    """Dataset wrapping tensors.
    Each sample will be retrieved by indexing tensors along the first dimension.
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
 
    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)
 
    def __len__(self):
        return self.tensors[0].size(0)
  • 比如: 

可以看出我们把X和Y通过Data.TensorDataset() 这个函数拼装成了一个数据集,数据集的类型是【TensorDataset】

import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

5.Dataloader

DataLoader是Pytorch中用来处理模型输入数据的一个工具类。组合了数据集(dataset) + 采样器(sampler),并在数据集上提供单线程或多线程(num_workers )可迭代对象

  • 基本概念:
  1. epoch:      所有的训练样本输入到模型中称为一个epoch;
  2. iteration:  一批样本输入到模型中,成为一个Iteration;
  3. batchszie:批大小,决定一个epoch有多少个Iteration;

          迭代次数(iteration)=样本总数(epoch)/批尺寸(batchszie)

  • 函数原型:
torch.utils.data.DataLoader(dataset, batch_size=1, 
    shuffle=False, sampler=None, 
    batch_sampler=None, num_workers=0, 
    collate_fn=None, pin_memory=False, 
    drop_last=False, timeout=0, 
    worker_init_fn=None, multiprocessing_context=None)
  • 参数:
  1. dataset (Dataset) – 决定数据从哪读取或者从何读取;

  2. batch_size (python:int, optional) – 批尺寸(每次训练样本个数,默认为1)

  3. shuffle (bool, optional) –每一个 epoch是否为乱序 (default: False).

  4. num_workers (python:int, optional) – 是否多进程读取数据(默认为0);

  5. drop_last (bool, optional) – 当样本数不能被batchsize整除时,最后一批数据是否舍弃(default: False)

  6. pin_memory(bool, optional) - 如果为True会将数据放置到GPU上去(默认为false) 

参考:

PyTorch学习之路(level2)——自定义数据读取_accimage_loader_AI之路的博客-CSDN博客

Pytorch(五)入门:DataLoader 和 Dataset_pytorch dataset_嘿芝麻的博客-CSDN博客

https://www.cnblogs.com/yongjieShi/p/10456802.html

https://www.cnblogs.com/ranjiewen/p/10128046.html

Python 子类继承父类构造函数:Python 子类继承父类构造函数说明 | 菜鸟教程

https://www.ziiai.com/blog/259

Python可迭代对象,迭代器,生成器的区别:Python可迭代对象,迭代器,生成器的区别_可迭代对象是什么意思_jinixin的博客-CSDN博客

完全理解Python迭代对象、迭代器、生成器:完全理解Python迭代对象、迭代器、生成器 - FooFish

Pytorch中的数据加载艺术:http://studyai.com/article/11efc2bf

PyTorch 数据集(Dataset):https://geek-docs.com/pytorch/pytorch-tutorial/pytorch-dataset.html

https://www.cnblogs.com/marsggbo/p/11308889.html

  • 85
    点赞
  • 216
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
### 回答1: DatasetDataLoaderPyTorch 中用于加载和处理数据的两个主要组件。Dataset 用于从数据源中提取和加载数据,DataLoader 则用于将数据转换为适合机器学习模型训练的格式。 ### 回答2: 在PyTorch中,DatasetDataLoader是用于处理和加载数据的两个重要类。 Dataset是一个抽象类,用于表示数据集对象。我们可以自定义Dataset子类来处理我们自己的数据集。通过继承Dataset类,我们需要实现两个主要方法: - __len__()方法:返回数据集的大小(样本数量) - __getitem__(idx)方法:返回索引为idx的样本数据 使用Dataset类的好处是可以统一处理训练集、验证集和测试集等不同的数据集,将数据进行一致的格式化和预处理。 DataLoader是一个实用工具,用于将Dataset对象加载成批量数据。数据加载器可以根据指定的批大小、是否混洗样本和多线程加载等选项来提供高效的数据加载方式。DataLoader是一个可迭代对象,每次迭代返回一个批次的数据。我们可以通过循环遍历DataLoader对象来获取数据。 使用DataLoader可以实现以下功能: - 数据批处理:将数据集划分为批次,并且可以指定每个批次的大小。 - 数据混洗:可以通过设置shuffle选项来随机打乱数据集,以便更好地训练模型。 - 并行加载:可以通过设置num_workers选项来指定使用多少个子进程来加载数据,加速数据加载过程。 综上所述,DatasetDataLoaderPyTorch中用于处理和加载数据的两个重要类。Dataset用于表示数据集对象,我们可以自定义Dataset子类来处理我们自己的数据集。而DataLoader是一个实用工具,用于将Dataset对象加载成批量数据,提供高效的数据加载方式,支持数据批处理、数据混洗和并行加载等功能。 ### 回答3: 在pytorch中,Dataset是一个用来表示数据的抽象类,它封装了数据集的访问方式和数据的获取方法。Dataset类提供了读取、处理和转换数据的功能,可以灵活地处理各种类型的数据集,包括图像、语音、文本等。用户可以继承Dataset类并实现自己的数据集类,根据实际需求定制数据集。 Dataloader是一个用来加载数据的迭代器,它通过Dataset对象来获取数据,并按照指定的batch size进行分批处理。Dataloader可以实现多线程并行加载数据,提高数据读取效率。在训练模型时,通常将Dataset对象传入Dataloader进行数据加载,并通过循环遍历Dataloader来获取每个batch的数据进行训练。 DatasetDataloader通常配合使用,Dataset用于数据的读取和预处理,Dataloader用于并行加载和分批处理数据。使用DatasetDataloader的好处是可以轻松地处理大规模数据集,实现高效的数据加载和预处理。此外,DatasetDataloader还提供了数据打乱、重复采样、数据划分等功能,可以灵活地控制数据的访问和使用。 总之,DatasetDataloaderpytorch中重要的数据处理模块,它们提供了方便的接口和功能,用于加载、处理和管理数据集,为模型训练和评估提供了便利。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

火柴的初心

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值