python 使用yield实现自己的data_loader

知识点回顾

  • yield的是什么?
    • yield在函数中的作用相当于是return
    • 不同的是,函数调用return返回后,下一次调用函数,函数从头开始执行;yield就是返回一个值,并且记住这个返回的位置,下次迭代就从这个位置后开始.
    • 一个小例子
      yield有两种用法,配置for循环使用或者配合next使用。
def yield_fun():
    for i in range(10):
        yield i

if __name__ == "__main__":
    fun1  = yield_fun()
    print('配合for使用')
    # 打印0~9
    for i in fun1:
        print(i) 

    print('配合next使用')
    fun2 = yield_fun()
    # 打印0
    print(next(fun2))
    # 打印1
    print(next(fun2))
  • data_loder是什么?怎么用?
    data_loader一般在训练使用,类似下面的代码
for epoch in range(max_epoch):
	data_loader = data_provider(x,y,batch_sz,shuffle=True)
	for batct_x,batch_y in data_loader:
		pre_y = net(batct_x)
		loss = loss_fun(pre_y,batch_y)
		loss.backward()
		optimizer.step()

自己的data_loader

看起来data_loader就是yield配合for来用。
确实python自己定义了data_loader,可以每次返回一个batch的数据。
但是我在返回每个batch的时候需要对数据进行处理:每次除了返回batch的数据,还需要从一个字典中获取一个batch的数据。不知道怎么使用官方的,所以就定义一个自己的。

import random
def data_provider(x1,x2_index,x2_dict,y,batch_sz,shuffle=True):
	#需要打乱顺序,shuffle为true
	if(shuffle):
		shuffle_list = list(zip(x1,x2_index,y))
		random.shuffle(shuffle_list)
		x1,x2_index,y =map(np.array,zip(*shuffle_list))
	for i in range(x1.shape[0]):
		x2 =list()
		for j in range(batch_sz):
			if(i+j == x1.shape[0]):
				yield x1[i:i+j,:],np.array(x2),y[i:i+j,:]
				#表示没有数据了,结束
				return
			else:
				x2.append(x2_dict[x2_index[i+j]])
		yield x1[i:i+batch_sz,:],np.array(x2),y[i:i+batch_sz,:]

使用自己的data_loader

data_loader = data_provider(x1,x2_index,x2_dict,y,batch_sz,shuffle=True)
for batct_x1,batct_x2,batch_y in data_loader:
		...
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值