#Python tensorflow中getbatch函数(yield实现),不会stopinteraction

在深度学习的入门教程中,很多深度学习的模型都是用手写数字mnist数据集进行训练的,在使用过程中通常都有一个batch的分批处理,类似这个:

在这里插入图片描述
这个next_batch函数是tensorflow中的函数,我们直接找源码过去copy也不太现实,我们就按照大概的方法写一个,如下:

def get_batches(x, y, n_batches):
	batch_size = len(x) // n_batches

	ii = 0
	while ii < n_batches * batch_size:
		# 判断如果这不是最后一个batch,那么这个batch中应该有batch_size个数据
		if ii != (n_batches - 1) * batch_size:
			X, Y = x[ii: ii + batch_size], y[ii: ii + batch_size]
		# 如果是最后一个batch,则剩余不够batch_size的数据都要凑入一个batch中
			ii += batch_size
		else:
			X, Y = x[ii: ], y[ii: ]
			# 能走到这一步说明数据已经取完了,为了避免抛出异常可以把ii设置为0,继续while循环
			ii = 0
		# 生成器语法,返回X, Y
		yield X, Y

这样函数就写好了,接下来就是调用的时候:

# 首先在外部定义batch
batch1 = get_batches(data_x_train, data_y_train, n_batches)
# 在循环中不断的get batch,可以一直取,不会stopinteraction
for i in range(training_step):
	x_batch, y_batch = batch1.__next__()
	# 然后x_batch和y_batch就可以看需要是否要reshape一下,然后就放进去训练了
	x_batch = x_batch.values.reshape([-1, n_steps, n_inputs]) 
	# 我在跑LSTM的数据格式可能跟你们不一样

纯手打无粘贴,如有错误请评论或联系我

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值