知识点回顾
yield
的是什么?yield
在函数中的作用相当于是return- 不同的是,函数调用
return
返回后,下一次调用函数,函数从头开始执行;yield
就是返回一个值,并且记住这个返回的位置,下次迭代就从这个位置后开始. - 一个小例子
yield
有两种用法,配置for
循环使用或者配合next
使用。
def yield_fun():
for i in range(10):
yield i
if __name__ == "__main__":
fun1 = yield_fun()
print('配合for使用')
# 打印0~9
for i in fun1:
print(i)
print('配合next使用')
fun2 = yield_fun()
# 打印0
print(next(fun2))
# 打印1
print(next(fun2))
- data_loder是什么?怎么用?
data_loader一般在训练使用,类似下面的代码
for epoch in range(max_epoch):
data_loader = data_provider(x,y,batch_sz,shuffle=True)
for batct_x,batch_y in data_loader:
pre_y = net(batct_x)
loss = loss_fun(pre_y,batch_y)
loss.backward()
optimizer.step()
自己的data_loader
看起来data_loader就是yield配合for来用。
确实python自己定义了data_loader,可以每次返回一个batch的数据。
但是我在返回每个batch的时候需要对数据进行处理:每次除了返回batch的数据,还需要从一个字典中获取一个batch的数据。不知道怎么使用官方的,所以就定义一个自己的。
import random
def data_provider(x1,x2_index,x2_dict,y,batch_sz,shuffle=True):
#需要打乱顺序,shuffle为true
if(shuffle):
shuffle_list = list(zip(x1,x2_index,y))
random.shuffle(shuffle_list)
x1,x2_index,y =map(np.array,zip(*shuffle_list))
for i in range(x1.shape[0]):
x2 =list()
for j in range(batch_sz):
if(i+j == x1.shape[0]):
yield x1[i:i+j,:],np.array(x2),y[i:i+j,:]
#表示没有数据了,结束
return
else:
x2.append(x2_dict[x2_index[i+j]])
yield x1[i:i+batch_sz,:],np.array(x2),y[i:i+batch_sz,:]
使用自己的data_loader
data_loader = data_provider(x1,x2_index,x2_dict,y,batch_sz,shuffle=True)
for batct_x1,batct_x2,batch_y in data_loader:
...