目录
一、python迭代器生成器基础讲解
1.1可迭代对象Iterable
表示该对象可迭代,并不一定是一个数据类型,如字典,字符串,列表等,它也可以是一个实现了__iter__方法的类。
from collections.abc import Iterable, Iterator
class A(object):
def __init__(self):
self.a = [1, 2, 3]
def __iter__(self):
# 此处返回啥无所谓
return self.a
cls_a = A()
# True
print(isinstance(cls_a, Iterable))
如果对象是Iterable,依然无法用for循环遍历,因为Iterable仅仅是提供了一种抽象规范接口。
1.2迭代器Iterator
如果一个对象是迭代器,那么它肯定是可迭代的,但是如果一个对象是可迭代的,它不一定是迭代器。实现了 __next__
和 __iter__
方法的类才能称为迭代器,就可以被 for 遍历了。
class A(object):
def __init__(self):
self.index = -1
self.a = [1, 2, 3]
# 必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
# 因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
def __iter__(self):
return self
def __next__(self):
self.index += 1
if self.index < len(self.a):
return self.a[self.index]
else:
# 抛异常,for 内部会自动捕获,表示迭代完成
raise StopIteration("遍历完了")
cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True
print(isinstance(iter(cls_a), Iterator)) # True
for a in cls_a:
print(a)
# 打印 1 2 3
1.3for in 的本质流程
for.....in...被python编译器编译后,如下
# 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
cls_a = iter(cls_a)
while True:
try:
# 然后调用对象的 __next__ 方法,不断返回元素
value = next(cls_a)
print(value)
# 如果迭代完成,则捕获异常即可
except StopIteration:
break
可见,任何一个对象要能被for遍历,必须实现__iter__和__next__两个方法。
list是可迭代对象,但是没next方法,为什么可以实现for循环遍历。list内部的iter方法的内部实现了next方法。
所以得到:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 __iter__
和 __next__
方法(对象必然是 Iterator),还是只实现 __iter__
(不是 Iterator),但是内部间接返回了具备 __next__
对象的类,都是可行的。
1.4 getitem
上面说过for in本质就是调用__iter__和__next__方法,实际上还有一种更简单的方法,__getitem__方法就可以让对象实现迭代功能。实际上任何一个类,只要实现了__getitem__方法,那么当调用iter(类实例)时候会自动具备__iter__和__next__方法。__getitem__
实际上是属于 iter和
next方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。
class A(object):
def __init__(self):
self.a = [1, 2, 3]
def __getitem__(self, item):
return self.a[item]
cls_a = A()
print(isinstance(cls_a, Iterable)) # False
print(isinstance(cls_a, Iterator)) # False
print(dir(cls_a)) # 仅仅具备 __getitem__ 方法
cls_a = iter(cls_a)
print(dir(cls_a)) # 具备 __iter__ 和 __next__ 方法
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True
# 等价于 for .. in ..
while True:
try:
# 然后调用对象的 __next__ 方法,不断返回元素
value = next(cls_a)
print(value)
# 如果迭代完成,则捕获异常即可
except StopIteration:
break
# 输出: 1 2 3
如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 __len__
方法即可。
此时我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。
1.5 yield 生成器
生成器是一个在行为上和迭代器非常类似的对象,两者功能差不多,但生成器更优雅,只需要用关键字yield来返回。作用于函数上叫生成器函数,调用函数返回一个生成器。
def func():
for a in [1, 2, 3]:
yield a
cls_g = func()
print(isinstance(cls_g, Iterator)) # True
print(dir(cls_g)) # 自动具备 __iter__ 和 __next__ 方法
for a in cls_g:
print(a)
# 输出: 1 2 3
# 一种更简单的写法是用 ()
cls_g = (i for i in [1,2,3])
使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。
总结:在迭代对象基础上,如果实现了 __next__
方法则是迭代器对象,该对象在调用 next() 的时 候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。
对于采用语法糖 __getitem__
实现的迭代器对象,其本身实例既不是可迭代对象,更不是 迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成 迭代器。
生成器是一种特殊迭代器,但是不需要像迭代器一样实现__iter__
和__next__
方法,只需要 使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入 yield 关键字实现。
对于在类的 __iter__
方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对 象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后 就会自动变成迭代器。
二、DataLoader的基础实现
首先介绍5个基本的对象:
Dataset提供整个数据集的随机访问功能,每次访问都返回单个对象,例如一个对象和一个target。
Sampler提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引。常用的子类是SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引
BatchSampler内部调用Sampler实列,输出指定batch_size个索引,然后将索引作用于Dataset上从而输出batch_size个数据对象,例如batch_size个数据和索引。
Collate_fn用于将batch_size个数据对象在batch维度进行聚合,生成(batch,.....)格式的数据输出。如果待聚合对象是numpy,则自动转化为tensor,此时就可以输入到网络中了。
迭代一次伪代码如下(非迭代器版本)
class DataLoader(object):
def __init__(self):
#假设数据长度为100,batch_size是4
self.dataset=[[img0,target0],[img1,target1],.....[img99,target99]]
self.sampler=[0,1,2,.....,99]
self.batch_size=4
self.index=0
def collate_fn(self,data):
#在batch维度聚合数据
batch_img=torch.Stack(data[0],0)
batch_target=torch.stack(data[1],0)
return batch_img,batch_target
def __next__(self):
i=0
batch_index=[]
while i<self.batch_size:
#内部会调用sampler对象获取单个索引
batch_index.append(self.sampler[self.index])
self.index+=1
i+=1
#得到batch_size个索引之后,调用dataset对象
data=[self.dataset[idx] for idx in batch_index]
#调用collate_fn 在batch维度进行拼接输出
batch_data=self.collate_fn(data)
return batch_data
def __iter__(self):
return self
# torch.stack()是指将列表里面的张量进行扩维拼接
# data=[[torch.Tensor([1]),torch.Tensor([1])],[torch.Tensor([1]),torch.Tensor([1])]]
# print(torch.stack(data[0],0),torch.stack(data[1],0))
# data=[torch.Tensor([1,2,3]),torch.Tensor([4,5,6])]
# print(torch.stack(data))
以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。
首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器。
Dataset通过__getitem__方法使其可迭代
Sample对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 __iter__
内部返回迭代器,RandomSampler 在 __iter__
内部通过 yield 关键字返回迭代器
Batchsampler也是在__iter__内部通过yield关键字返回迭代器
DataLoader通过__iter__和__next__直接实现迭代器
除了DataLoader本身是迭代器外,其余对象本身都不是迭代器,但可以for in迭代
由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。
需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。
三、整体框架的讲解
核心运行逻辑:
def __next__(self):
#返回batch个索引
index=next(self.batch_sampler)
#利用索引去取数据
data=[self.dataset[idx] for idx in index]
#batch维度聚合
data=self.collate_fn(data)
return data
整体流程:
1.self.batch_sampler=iter(batch_sampler)。在DataLoader的类初始化,需要得到BatchSampler的迭代器对象。
2.index=next(self.batch_sampler)。对于每次迭代,DataLoader对象首先会调用BatchSampler的迭代器进行下一次迭代,具体是调用BatchSampler对象的__iter__方法
3.而BatchSampler对象的__iter__方法实际上是需要依靠Sampler对象进行迭代输出索引,Sampler对象也是一个迭代器,当迭代batch_size次后就可以得到batch_size个数据索引。
4.data=[self.dataset[idx] for idx in index]。有了batch个索引就可以通过不断调用dataset的__getitem__方法返回数据对象,此时data就包含了batch个对象。
5.data=self.collate_fn(data)。将batch个对象输入给聚合函数,在第0个维度也就是batch维度进行聚合,得到类似(batch,....)的对象。
6.重复上面的操作,就可以不断输出一个一个的batch数据
class Dataset(object):
#只要实现了__getitem__方法就可以变成迭代器
def __getitem__(self,index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class Sampler(object):
def __init__(self,data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
#一般出现raise NotImplementedError这个错误,就是子类没有重写父类中的成员函数,然后子类对象调用此函数会报这个错误
class SequentialSampler(sampler):
def __init__(self,data_source):
super(SequentialSampler,self).__init__(data_source)
self.data_source=data_source
def __iter__(self):
#返回迭代器,不然无法for in
return iter(range(len(self.data_source))
def __len__(self):
return len(self.data_source)
class BatchSampler(Sampler):
def __init__(self,sampler,batch_size,drop_last):
self.sampler=sampler
self.batch_size=batch_size
self.dorp_last=drop_last
def __iter__(self):
batch=[]
for idx in self.sampler:
batch.append(idx)
#如果得到了batch个索引,则可以通过yield关键字生成生成器返回,得到迭代器对象
if len(batch)==self.batch_size:
yield batch
batch=[]
if len(batch)>0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
#如果最后的索引数不等于一个batch,抛弃
return len(self.sampler)//self.batch_size
else:
return (len(self.sampler)+self.batch_size-1)//self.batch_size
class DataLoader(object):
def __init__(self,dataset,batch_size=1,shuffle=False,sample=None,batch_sampler=None,
collate_fn=None,drop_last=False):
self.dataset=dataset
#因为这两个功能是冲突的
if sampler is not None and shuffle:
raise ValueError('sampler option is ..')
if batch_sampler is not None:
# 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
# 和 drop_last 四个参数就不能传入
# 因为这4个参数功能和 batch_sampler 功能冲突了
if batch_size != 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler option is mutually exclusive '
'with batch_size, shuffle, sampler, and '
'drop_last')
batch_size = None
drop_last = False
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
# 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
if batch_sampler is None:
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = iter(batch_sampler)
if collate_fn is None:
collate_fn = default_collate
self.collate_fn = collate_fn
#核心代码
def __next__(self):
index=next(self.batch_sampler)
data=[self.dataset[idx] for idx in index]
data=self.collate_fn(data)
return data
#返回自身,因为自身实现了next
def __iter__(self):
return self
def default_collate(batch):
elem=batch[0]
elem_type=type(elem)
if isinstance(elem,torch.Tensor):
return torch.stack(batch,0)
elif elem_type.__module__=='numpy':
return default_collate([torch.as_tensor(b) for b in batch])
else:
raise NotImplementedError
完整调用例子
class Simplev1Dataset(Dataset):
def __init__(self):
#伪造数据
self.imgs=np.arange(0,16).reshape(8,2)
def __getitem__(self,index):
return self.imgs[index]
def __len__(self):
return self.imgs.shape[0]
from simplev1_dataset import Simplev1Dataset
simple_dataset=Simplev1Dataset()
dataloader=DataLoader(simple_dataset,batch_size=2,collate_fn=default_collate)
for data in dataloader:
print(data)