动机
在实验时,碰到了需要自定义Sampler的情况。因此,出于使自己放心的动机,对DataLoader的源码进行了分析,了解了DataLoader的内部运行机制,明白了其是如何通过Sampler来操作DataSet中的数据的,这下可以放心的写Sampler了!
相关概念
可迭代对象
- 可迭代对象是 使用内置函数
iter()
可以获取 迭代器 的对象,即- 要么对象实现了能返回迭代器的
__init__()
方法 - 要么对象实现了
__getitem__()
方法,而且其参数是从零开始的索引
- 要么对象实现了能返回迭代器的
- 内置的
iter()
函数有以下作用- 检查对象是否实现了
__init__()
方法,如果实现了就调用它,获取一个迭代器 - 如果没有实现
__init__()
方法,但是实现了__getitem__()
方法,而且其参数是从零开始的索引,Python 会创建一个迭代器,尝试按顺序(从索引 0 开始)获取元素 - 如果前面两步都失败,Python 抛出 TypeError 异常,通常会提示“C objectis not iterable”(C 对象不可迭代),其中 C 是目标对象所属的类
- 检查对象是否实现了
- Python内置 str、list、tuple、dict、set、file 都是可迭代对象
迭代器
- 可迭代对象执行
__iter__()
方法得到的返回值是迭代器 - 迭代器对象指的是即内置有
__iter__()
又内置有__next__()
方法的对象 - 迭代器对象一定是可迭代对象,而可迭代对象不一定是迭代器对象
- 标准的迭代器接口有两个方法,即:
__next__()
:返回下一个可用元素,如果没有元素,抛出StopIteration异常__iter__()
:返回self,以便在应该使用可迭代对象的地方使用迭代器,比如for循环中
- Python内置 file 是迭代器
for循环的内部机制
- 先判断对象是否为可迭代对象,即是否满足可迭代对象的定义,如果满足则使用
iter()
方法,返回一个迭代器;否则,直接抛出TypeError异常 - 不断地调用迭代器的
__next__
方法,每次调用按顺序迭代获取当前的值 - 迭代完所有元素,就抛出异常 StopIteration,这个异常 python 解释器自己会处理
源码分析
源码流程图
源码解析 Sampler
-
所有的采样器都继承自
Sampler
这个类 -
需重写三种方法
class MySampler(Sampler): r"""Base class for all Samplers. Every Sampler subclass has to provide an __iter__ method, providing a way to iterate over indices of dataset elements, and a __len__ method that returns the length of the returned iterators. """ def __init__(self, data_source): pass // 通过该方法获取迭代器对象,可return返回一个迭代器,可yield得到一个生成器 def __iter__(self): raise NotImplementedError def __len__(self): // 返回数据的个数 raise NotImplementedError
参考Blog
https://www.cnblogs.com/marsggbo/p/11541054.html
https://www.cnblogs.com/marsggbo/p/11308889.html