dataset__getitem___PyTorch源码解析与实践(1):数据加载Dataset,Sampler与DataLoader

50ec567270794289e189b80081d26f50.png

66ed24462aed085e221df4854870c48c.png

献给学习PyTorch在路上或者计划较深入理解PyTorch的同行者们

写在前面

笔者一直使用tf,大势所趋决定转PyTorch,这个系列就作为我学习PyTorch的笔记与心得。

网络上PyTorch学习资源多如牛毛,因此我想写点不一样的。

我的计划:每篇文章介绍PyTorch一个/多个模块,内容主要为:

源码解析+实践demo

希望在代码层面更加透彻地理解和使用PyTorch(初步计划,后期可能会有改动)


1 源码解析

PyTorch的数据加载模块,一共涉及到Dataset,Sampler,Dataloader三个类

Dataset负责对raw data source封装,将其封装成Python可识别的数据结构,其必须提供提取数据个体的接口。Dataset共有Map-style datasets和Iterable-style datasets两种:

  • map-style dataset:实现了__getitem__和__len__接口,表示一个从索引/key到样本数据的map。比如:datasets[10],就表示第10个样本。

  • iterable-style dataset:实现了__iter__接口,表示在data samples上的一个Iterable(可迭代对象),这种形式的dataset非常不适合随机存取(代价太高),但非常适合处理流数据。比如:iter(datasets)获得迭代器,然后不断使用next迭代从而实现遍历。

Sampler负责提供一种遍历数据集所有元素索引的方式。

Dataloader负责加载数据,同时支持map-style和iterable-style Dataset,支持单进程/多进程,还可以设置loading order, batch size, pin memory等加载参数。

这三者的关系就一目了然了。

  1. 设置Dataset,将数据data source包装成Dataset类,暴露提取接口。

  2. 设置Sampler,决定采样方式。我们是能从Dataset中提取元素了,还是需要设置Sampler告诉程序提取Dataset的策略。

  3. 将设置好的Dataset和Sampler传入DataLoader,同时可以设置shuffle,batch_size等参数。使用DataLoader对象可以快捷方便地在给定数据集上遍历。

总结来说,即Dataloader负责总的调度,命令Sampler定义遍历索引的方式,然后用索引去Dataset中提取元素。于是就实现了对给定数据集的遍历。

1.1 Dataset

torch.utils.data.Dataset:抽象基类

  • 所有的Dataset相关类都应该继承自torch.utils.data.Dataset这个类

  • 这些子类必须要实现方法__getitem__(),来支持可以给定一个key(即索引)来获取对应的数据样本

  • 这些类可以实现方法__len__(),来返回数据集的大小规模

Dataset实现非常简洁,就只是提供了__getitem__ 和 __add__这两个接口。

前者很重要,是Dataset及其子类的核心,定义了数据元素提取(即通过索引获取样本,实际代码中常使用[]输入索引)

class Dataset(object):def __getitem__(self, index):        raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])

具体实践中,我们需要使用Dataset的子类,自己实现的或者现成的。

我们可以来看看PyTorch为我们提供的现成的Dataset子类:

  • TensorDataset

  • IterableDataset

  • ConcatDataset

  • ChainDataset

  • Subset

下面着重介绍TensorDataset和IterableDataset.

CLASS torch.utils.data.TensorDataset(*tensors)

包装了Tensor的Dataset子类,map-style dataset

每个样本可以通过tensors第一个维度的索引获取

class TensorDataset(Dataset):r"""    Arguments:        *tensors (Tensor): tensors that have the same size of the first dimension.    """def __init__(self, *tensors):assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)        self.tensors = tensorsdef __getitem__(self, index):return tuple(tensor[index] for tensor in self.tensors)def __len__(self):return self.tensors[0].size(0)

如上源码:

__init__的形参是*tensors,因此是可以传入多个tensor变量的,但需要保证每个tensor的第一个维度均是一样的。

栗子:输入张量:

正确输入:640?wx_fmt=svg

错误输入:640?wx_fmt=svg

__getitem__提取的就是*tensors中每个张量的第index个样本(因为每个张量第一维度都是一样的)

__len__即*tensors每个张量第一个维度长度

常见用法:*tensors指定我们可以输入多个张量,我们可以同时输入train_data和train_label

dataset = TensorDataset(train_data, train_label)

CLASS torch.utils.data.IterableDataset

内部样本的组织形式是Iterable的所有dataset类都是IterableDataset类的子类,即:所有iterable-style dataset都是IterableDataset的子类。

这种形式的dataset对于处理流数据是非常有用的。

所有这些子类需要实现__iter__方法(而不是__getitem__方法了),需要据此来返回样本的迭代器,从而遍历dataset(实际代码中常使用iter+next来遍历)

关于Python中Iterable和Iterator的介绍见我的另一篇文章: 刘昕宸:彻底搞懂Python的__iter__和__next__,Iterable和Iteration
class IterableDataset(Dataset[T_co]):def __iter__(self) -> Iterator[T_co]:raise NotImplementedErrordef __add__(self, other: Dataset[T_co]):return ChainDataset([self, other])

关于多进程的问题:

IterableDataset的某个子类被DataLoader使用时,dataset中的每个item可以通过DataLoader的Iterator迭代获取。

当num_works>0时就是多进程模式,每个工作进程都有一个不同的dataset对象的拷贝,因此我们需要独立安排每一份拷贝该如何处理(后面会有例子),以防止不同的进程会返回重复的元素。(有MPI编程经验的同学应该更能理解!)

可以通过get_worker_info方法,在某一当前进程中调用,获得当前进程信息。这个方法要么在dataset类的__iter__方法中使用,要么在DataLoader的worker_init_fn方法中设置并使用。

举2个例子(来自官网文档):

例1:在dataset类的__iter__方法中使用get_worker_info方法,划分工作空间,获得当前进程id,并根据进程id分配其需要处理的工作空间

class MyIterableDataset(torch.utils.data.IterableDataset):def __init__(self, start, end):super(MyIterableDataset).__init__()        assert end > start, "this example code only works with end >= start"self.start = startself.end = enddef __iter__(self):        worker_info = torch.utils.data.get_worker_info()if worker_info is None:  # 单进程:一个进程处理全部样本            iter_start = self.start            iter_end = self.endelse: # 多进程,在当前进程中# 划分工作空间        
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值