pytorch的数据读取

pytorch的数据读取

pytorch数据读取的核心是torch.utils.data.DataLoader类,具有以下特性:

  • 支持map-style datasets和iterable-style datasets
  • 自定义数据读取顺序
  • 自动批量化
  • 单线程/多线程读取
  • 自动内存锁页

1. 整体流程

DataLoader的参数如下,主要涉及DataSetsamplecollate_fnpin_memory

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

pytorch读取数据的整体处理流程如下图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1CBP60mO-1620994522059)(C:\Users\sfang.fangsh\Desktop\v2-0b0b53c3c58e1269cd87b7ab42d9f221_r.jpg)]

无论是map-style还是iterable-style dataset整体流程都是

  1. 先用采样器采样,采样一次得到一个样本的索引(iterable-style dataset无法通过索引取值,所以使用的是一个虚假的采样器,每次生成None)。
  2. 使用batch_sampler生成长度为batch_size的索引列表(实际是使用sampler采样batch_size次)。
  3. 使用collate_fn将batch_size长度的列表整理成batch样本(tensor格式)。

2. DataSet Types

pytorch支持两种类型的数据集map-style datasetiterable-style dataset

Map-style datasets

字典型数据集是指实现了__getitem__()__len__()协议,表示从索引到数据样本的映射。

可以继承抽象类torch.utils.data.DataSet,并重写__getitem__()__len__()方法。

Note: DataLoader默认构造的采样器返回的都是整数索引,如果dataset的索引不是整数,需要自定义采样器。

Iterable-style datasets

可迭代型数据集是torch.utils.data.IterableDataset的子类,需要实现__iter__()协议,表示对数据样本的一轮迭代。

iterable-style dataset类似python的可迭代对象。使用iter()方法会得到一个迭代器,每次调用next()会得到下一个样本。无法使用索引取元素。所以就不能使用采样器采样得到索引,在使用索引得到样本。dataloader的实现中,对于可迭代类型的数据集会使用一个虚假采样器InfiniteConstantSampler。每次调用都返回None。

class _InfiniteConstantSampler(Sampler):
    r"""Analogous to ``itertools.repeat(None, None)``.
    Used as sampler for :class:`~torch.utils.data.IterableDataset`.

    Args:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self):
        super(_InfiniteConstantSampler, self).__init__(None)

    def __iter__(self):
        while True:
            yield None

这个采样器的目的就是为了在batch_sample时控制采样的次数。

3. Sampler

对于IterableDataset来说,数据读取的顺利是由用户定义迭代决定的。回想下python的迭代器,只能通过循环调用next()方法,依次拿到下一个样本。不能改变原有的次序。

对may-style Dataset来说,sampler用来在数据读取时,指定样本索引的顺序。可以指定DataLoader的shuffle参数来指导顺序读取还是乱序读取。如果shuffle=True,会自动构造一个RandomSampler采样器,shuffle=False,会构造SequentialSample采样器。也可以用户自定义一个采样器并使用sample参数指定。自定义采样器每次返回下一个采样的索引注意采样器返回的都是样本索引,不是样本本身。需要根据索引得到样本。

batch_sampler

如果一个采样器sampler一次返回批量大小的索引列表,那么就叫做batch_sampler。如果指定batch_size和drop_last参数,就会基于sampler(采样器)自动构造一个batch_sampler(批量采样器)。map-style 数据集也可以使用batch_sampler参数指定自定义的批量采样器。

4. collate_fn

collate_fn从字面上看就是整理函数,是对batch_sampler批量采样器返回的长度是batch_size的索引列表进行加工,处理成模型可以使用的batch_size大小的tensor。

这里需要注意采样器sampler/batch_sampler返回的都是样本索引,collate_fn的输入是批量大小的样本列表。所以在传给collate_fn前要根据索引取样本。

如果是may-style数据集,这个操作大概等价于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

如果是iterable-style数据集,大概等价于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

可以看到iterable-style的索引其实是没用的,只是用来控制采样的个数。同时,发现collate_fn函数接收的参数是样本列表。collate_fn的一个重要功能就是把这个列表加工成pytorch支持的数据格式tensor。通过看pytorch的源码,如果不指定collate_fn,会使用默认的collate_fn函数,这个函数的功能就是将各种类型的数据转化成tensor。也可以自定义collare_fn函数,然后通过collate_fn参数指定,在自定义的函数中增加需要的操作。例如,将每个样本padding到当前batch的最大样本长度。任何想要对批量数据进行的操作都要定义在这个函数中。

pytroch的默认collate_fn实现:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

5. pin_memory

这部分来自博客

pin_memory是指锁页内存,什么是锁页内存?

内存分为锁页和不锁页,锁页内存存的内容在任何情况下都不会与机器的虚拟内存(虚拟内存就是硬盘)进行交换。不锁页内存在主机内存不足时,数据会存放到虚拟内存。

如果pin_memory=True,那么生成的数据都会放在锁页内存上,此时将tensor拷贝到GPU的显存会更快。

6. 自动批量化/非批量化

dataloader默认返回批量的样本(batch_size默认为1)。当参数batch_size和batch_sample均为None时,会关闭自动批量化操作。此时会将采样的单个样本传给collate_fn函数。

参考

[1]pytorch官方文档 TORCH.UTILS.DATA部分

[2]博客:pytorch创建data.DataLoader时,参数pin_memory的理解

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值