模拟tensorflow中next_batch的实现原理来熟悉yield的应用。
yield 可以看成是return,函数执行到yield便会退出函数,返回一个生成器。想要得到返回值必须使用next()来获取,否则只会得到一个生成器。且下次调用,会从yield出执。
行。
下面看模拟该函数的运行的过程。
import numpy as np
def next_batch(train,target,batch_size):
length=len(train)
index=[i for i in range(length)]
np.random.shuffle(index)
cnt=length/batch_size+1
while cnt>0:
batch_x=[]
batch_y=[]
try:
for i in range(batch_size):
batch_x.append(train[index[i]])
batch_y.append(target[index[i]])
index.remove(index[i])
except:
print("结束")
yield (batch_x,batch_y)
train=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
target=[0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1]
a=next_batch(train,target,3)
for i in range(7):
print(next(a))
输出的结果: ([5, 12, 18], [0, 1, 1]) ([1, 14, 11], [0, 1, 0]) ([7, 19, 2], [0, 0, 1]) ([15, 9, 3], [0, 0, 0]) ([20, 4, 6], [1, 1, 1]) ([8, 10, 13], [1, 1, 0]) 结束 ([16], [1])
debug该函数的实现过程
(1)第一次执行时,batch-x和batch-y分别为(10,1,12),(1,0,1),从yield处返回。
2.再次调用next(),直接从while()处执行。