PyTorch:Dataset()与Dataloader()的使用详解

目录

1、Dataset类的使用

2、Dataloader类的使用

3、总结


Dataset类与Dataloader类是PyTorch官方封装的用于在数据集中提取一个batch的训练用数据的接口,其实我们也可以自定义获取每个batch的方法,但是对于大数据量的数据集,直接用封装好的接口会很大程度上提升效率。

一般情况下,Dataset类与Dataloader类是配合着使用的,Dataset负责整理数据,Dataloader负责在整理好的数据中按照一定的规则取出batch_size个数据来供网络训练使用。

1、Dataset类的使用

Dataset用以整理数据集。我们整理数据的目的是为了Dataloader可以方便的从整理后的和数据中获取一个batch的数据来供网络进行训练。

先看一下官方的Dataset的源码:

class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

很明显,这个类内部什么方法的实现都没有,就是用来让我们继承重写的。当我们继承该类时,必须重写里面的__getitem__(self, index)方法。该方法定义了使用索引值来查找元素的方法,即假如我们定义一个自己的训练数据集实例traindata,如果想使用traindata[index]的方式来获取索引为index的数据,我们就得实现__getitem__方法。这样当我们调用traindata[index]索引数据时,其实就是自动调用__getitem__(self, index)方法来实现的。另外,我们还可以重写__len__(self)方法,用以使用len(traindata)方法来获取我们整个数据集的数量。如果还不清楚,可以细细品一下下面的例子:

class TrainData(Dataset):  # 继承Dataset类并重写相关的方法
    ...
    def __getitem__(self, index):
        '''编写自己的数据获取方式'''
        return [x_data, y_lable]

    def __len__(self):
        '''编写获取数据集大小的实现方式'''
        return length


traindata = TrainData(mydataset)   # 定义一个实例
first = traindata[0]       # 获取数据集中的第一组数据,会自动调用__getitem__
length = len(traindata)    # 获取数据集的数据量的方法,会自动调用__len__

2、Dataloader类的使用

整体上来说,Dataloader类就是从上面封装好的数据中按照给定的方式来一次一次地抽取一个batch的数据来供网络进行训练,其内部使用的是yield生成器机制。Dataloader不用继承重写,我们直接实例化就行。下面我们接着上面的例子来继续了解下Dataloader从数据集中取出一个batch数据的过程:

首先,定义一个Dataloader实例gen_train:

gen_train = Dataloader(traindata, batch_size=4, num_workers=4, pin_memory=True, drop_last=True, collate_fn=my_collate_fn)

关于有关参数的说明(没用到的参数就不解释了):

1、traindata(Dataset): 传入的数据集,按自己定义的Dataset实例名来传入,我这里是traindata
2、batch_size(int, optional): 每个batch有多少个样本
3、num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
4、pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
5、drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为4,而一个epoch只有100个样本,那么训练的时候后面的2个因为不满足组成一个batch就被扔掉了。如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
6、collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

可以看到,gen_train从traindata中返回的是一个含有batch_size(4)个数据([x_data, y_label])的mini_batch。

下面我们分析分析这个过程是咋实现的。首先,DataLoader(object)源码中有下面这么一段代码:

                。。。。。。
if sampler is None:  # give default samplers
    if self.dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)
                。。。。。。

按照上面的设置,sampler默认是None,我们没有定义要打乱数据(即shuffle为False),则接下来会调用

sampler = SequentialSampler(dataset)

再来看看这个方法是怎么实现的:

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

主要看__iter__部分,明显的,假设数据集共有n个数据,这是一个返回的sampler就是数据集长度[0,1,2,......,n-1]序号的迭代器。关于怎么迭代,我们回到DataLoader(object)源码中继续往先看,会发现这么几条代码:

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

首先说一下,这个代码就是从上一步的迭代器sampler中取出batch_size个序号,batch_size之前我们设置的是4,所以就是取出4个序号(索引),用以后面从traindata中取出batch_size个数据,来看一下BatchSampler方法的迭代方式的实现,注意这里的yield机制:

class BatchSampler(Sampler):
        。。。。。。
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
        。。。。。。

所以,到这里我们一个batch_size的数据的索引就已经有了,后面就是调用多线程或单线程机制来取出对应的数据traindata[i]了。回到DataLoader(object)源码中,在往下看,就是下面这段代码了:

def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)

这段代码就是DataLoader的迭代器的实现方式了,具体的单多线程实现就不详细展开了。此时我们已经完成了获取本次迭代所需要的数据的索引值,接下来即使按照索引在traindata中找到相应的数据并一起返回这个mini_batch了。比如我们可以这样获取数据并用于训练:

for iteration, batch in enumerate(gen_train):
    if iteration >= epoch_size:  # 判断是否到达一个epoch的迭代次数(len(traindata)/batchsize)
        break
    x_datas, y_labels= batch[0], batch[1]  # 获取batch中的数据和标签,用于训练
                ......

我们就可以使用这批数据进行一次网络的训练了,这么周而复始,直至达到我们设置的epoch。

3、总结

一般情况下,Dataset类与Dataloader类是配合着使用的,Dataset负责整理数据,Dataloader负责从Dataset整理好的数据中按照一定的规则取出batch_size个数据来供网络训练使用。

  • 6
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
PyTorchDataset是一个用来构造支持索引的数据集的类。但是Dataset类本身不能实例化,所以在使用Dataset时,我们需要定义自己的数据集类,也就是继承自Dataset类的子类,来继承Dataset类的属性和方法。这样我们就可以根据自己的数据集的具体需求来实现自定义的数据加载逻辑,例如读取图片、文本等,并将其转换为模型能够接受的格式。通过使用Dataset和自定义的数据集类,我们可以方便地加载和处理数据,为模型的训练提供输入。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [PyTorch基础之数据模块DatasetDataLoader用法详解(附源码)](https://blog.csdn.net/jiebaoshayebuhui/article/details/130439027)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [PyTorch 解决DatasetDataloader遇到的问题](https://download.csdn.net/download/weixin_38746515/12856519)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

地球被支点撬走啦

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值