add函数 pytorch_PyTorch 源码解读之 torch.utils.data

点击上方“AIWalker”,选择加"星标"重磅干货,第一时间送达作者:OpenMMLab知乎:https://zhuanlan.zhihu.com/p/337850513本文已获作者授权转载,不得擅自二次转载0 前言本文涉及的源码以 PyTorch 1.7 为准迭代器理解 Python 的迭代器是解读 PyTorch 中torch.utils.data模块的关键。在Datase...
摘要由CSDN通过智能技术生成

 点击上方“AIWalker”,选择加"星标"

重磅干货,第一时间送达8d8bb5bb6d7f2e1139c956e5a5aa3ad4.png

作者:OpenMMLab
知乎:https://zhuanlan.zhihu.com/p/337850513
本文已获作者授权转载,不得擅自二次转载

0 前言

本文涉及的源码以 PyTorch 1.7 为准

迭代器

理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键。

在 DatasetSampler 和 DataLoader 这三个类中都会用到 python 抽象类的魔法方法,包括__len__(self)__getitem__(self) 和 __iter__(self)

  • __len__(self): 定义当被 len() 函数调用时的行为,一般返回迭代器中元素的个数

  • __getitem__(self): 定义获取容器中指定元素时的行为,相当于 self[key] ,即允许类对象拥有索引操作

  • __iter__(self): 定义当迭代容器中的元素时的行为

迭代的意思类似于循环,每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典,这些数据结构都支持迭代操作。

实现迭代器的魔法方法有两个:__iter__(self) 和 __next__(self)

一个容器如果是迭代器,那就必须实现 __iter__(self) 魔法方法,这个方法实际上是返回是一个迭代器(通常是迭代器本身)。接下来重点要实现的是 __next__(self) 魔法方法,因为它决定了迭代的规则。

class Fibs:
def __init__(self, n=20):
self.a = 0
self.b = 1
self.n = n
def __iter__(self):
return self
def __next__(self):
self.a, self.b = self.b, self.a + self.b
if self.a > self.n:
raise StopIteration
return self.a

fibs = Fibs()
for each in fibs:
print(each)

# 输出
# 1 1 2 3 5 8 13

一般来说,迭代器满足以下几种特性:

  • 迭代器是⼀个对象

  • 迭代器可以被 next() 函数调⽤,并返回⼀个值

  • 迭代器可以被 iter() 函数调⽤,并返回一个迭代器(可以是自身)

  • 连续被 next() 调⽤时依次返回⼀系列的值

  • 如果到了迭代的末尾,则抛出 StopIteration 异常

  • 迭代器也可以没有末尾,只要被 next() 调⽤,就⼀定会返回⼀个值

  • Python 中, next() 内置函数调⽤的是对象的 next() ⽅法

  • Python 中, iter() 内置函数调⽤的是对象的 iter() ⽅法

  • ⼀个实现了迭代器协议的的对象可以被 for 语句循环迭代直到终⽌

了解了什么是迭代器后,我们就可以开始解读 torch.utils.data 模块

对于 torch.utils.data 而言,重点是其 DatasetSamplerDataLoader 模块,辅以 collatefetchpin_memory 等组件对特定功能予以支持。

1 Dataset

Dataset 负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。

Dataset 共有 Map-style datasets 和 Iterable-style datasets 两种:

1.1 Map-style dataset

torch.utils.data.Dataset

它是一种通过实现 __getitem__() 和 __len()__ 来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。访问时,这样的数据集用 dataset[idx] 访问 idx 对应的数据。

通常我们使用 Map-style 类型的 dataset 居多,其数据接口定义如下:

class Dataset(Generic[T_co]):
# Generic is an Abstract base class for generic types.

def __getitem__(self, index) -> T_co:
raise NotImplementedError

def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])

PyTorch 中所有定义的 Dataset 都是其子类。

对于一般计算机视觉任务,我们通常会在其中进行一些 resize, crop, flip 等预处理的操作

值得一提的是,PyTorch 源码中并没有提供默认的 __len__() 方法实现,原因是 return NotImplemented 或者 raise NotImplementedError() 之类的默认实现都会存在各自的问题,这点在其源码中也有注释加以体现。

1.2 Iterable-style dataset

torch.utils.data.IterableDataset

它是一种实现 __iter__() 来获取数据的 Dataset,这种类型的数据集特别适用于以下情况:随机读取代价很大甚至不大可能,且 batch size 取决于获取的数据。其接口定义如下:

class IterableDataset(Dataset[T_co]):

def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError

def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])

特别地,当 DataLoader 的 num_workers > 0 时, 每个 worker 都将具有数据对象的不同样本。因此需要独立地对每个副本进行配置,以防止每个 worker 产生的数据不重复。同时,数据加载顺序完全由用户定义的可迭代样式控制。这允许更容易地实现块读取和动态批次大小(例如,通过每次产生一个批次的样本)

1.3 其他 Dataset

除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基础上提供了其他类型的 Dataset 子类

  • torch.utils.data.ConcatDataset: 用于连接多个 ConcatDataset 数据集

  • torch.utils.data.ChainDataset : 用于连接多个 IterableDataset 数据集,在 IterableDataset 的 __add__() 方法中被调用

  • torch.utils.data.Subset: 用于获取指定一个索引序列对应的子数据集

class Subset(Dataset[T_co]):

dataset: Dataset[T_co]
indices: Sequence[int]

def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices

def __getitem__(self, idx):
return self.dataset[self.indices[idx]]

def __len__(self):
return len(self.indices)
  • torch.utils.data.TensorDataset: 用于获取封装成 tensor 的数据集,每一个样本都通过索引张量来获得。

class TensorDataset(Dataset):
def __init__(self, *tensor):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors

def __getitem__(self, index):
return tuple(tensor[index] for tensor in tensors

def __len__(self):
return self.tensors[0].size(0)

2 Sampler

torch.utils.data.Sampler 负责提供一种遍历数据集所有元素索引的方式。可支持用户自定义,也可以用 PyTorch 提供的,基类接口定义如下:

lass Sampler(Generic[T_co]):
r"""Base class for all Samplers. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a way to iterate over indices of dataset elements, and a :meth:`__len__` method that returns the length of the returned iterators. .. note:: The :meth:`__len__` method isn't strictly required by :class:`~torch.utils.data.DataLoader`, but is expected in any calculation involving the length of a :class:`~torch.utils.data.DataLoader`. """

def __init__(self, data_source: Optional[Sized]) -> None:
pass

def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError

特别地,__len__()

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值