python的迭代器_Python-迭代器与生成器

迭代器

迭代是Python最强大的功能之一,是访问集合元素如list,tuple的一种方式。迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。

迭代器有两个基本的方法:iter() 和 next()。

迭代器最常见的就是引用于深度学习中,一个batch一个batch的读数据,避免内存爆掉的情况,比如Pytorch中最常用的Dataloader(差点忘了经常自定义的dataloader是迭代器的原理)

如:

迭代器对象可以使用常规for语句进行遍历,不需要担心越界问题,因为迭代器内部存在着迭代终止标记StopIteration。

迭代器有什么好处?

在for循环中,如果是python2,则需要系统一次性分配相应长度的空间给list,而在python3中,只需要一个range对象,所需值由这个对象内置的next方法迭代返回,节省了空间。

创建一个迭代器

想要自定义一个迭代器需要在你的类中实现两个方法:iter() 与 next() 。

iter() 方法返回一个特殊的迭代器对象, 这个迭代器对象实现了 next() 方法并通过 StopIteration 异常标识迭代的完成。

next() 方法会返回迭代器的输出。

class MyNumbers:

def __iter__(self):

self.a = 1

return self

def __next__(self):

x = self.a

self.a += 1

return x

myclass = MyNumbers()

myiter = iter(myclass)

print(next(myiter))

print(next(myiter))

print(next(myiter))

print(next(myiter))

print(next(myiter))

1

2

3

4

5

当然,这样一个迭代器是危险的,因为其没有判断什么时候迭代停止,如果运用在for循环中会导致无限循环。这时候就需要一个StopIteration

StopIteration 用于标识迭代的完成,防止出现无限循环的情况,在 next() 方法中我们可以设置在完成指定循环次数后触发 StopIteration 异常来结束迭代。

class MyNumbers:

def __iter__(self):

self.a = 1

return self

def __next__(self):

if self.a <= 20:

x = self.a

self.a += 1

return x

else:

raise StopIteration

myclass = MyNumbers()

myiter = iter(myclass)

生成器

在 Python 中,使用了 yield 的函数被称为生成器(generator)。

跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。

在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。

调用一个生成器函数,返回的是一个迭代器对象。

import sys

def fibonacci(n): # 生成器函数 - 斐波那契

a, b, counter = 0, 1, 0

while True:

if (counter > n):

return

yield a

a, b = b, a + b

counter += 1

f = fibonacci(10) # f 是一个迭代器,由生成器返回生成

while True:

try:

print (next(f), end=" ")

except StopIteration:

sys.exit()

Pytorch中的生成器的应用

自定义数据加载器的时候,我们需要继承Dataset类并重写__len__和__getitem__ 这两个方法

import torch

class Dataset(torch.utils.data.Dataset):

'Characterizes a dataset for PyTorch'

def __init__(self, list_IDs, labels):

'Initialization'

self.labels = labels

self.list_IDs = list_IDs

def __len__(self):

'Denotes the total number of samples'

return len(self.list_IDs)

def __getitem__(self, index):

'Generates one sample of data'

# Select sample

ID = self.list_IDs[index]

# Load data and get label

X = torch.load('data/' + ID + '.pt')

y = self.labels[ID]

return X, y

#生成数据generator

training_set = Dataset(partition['train'], labels)

training_generator = torch.utils.data.DataLoader(training_set, **params)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值