在深度学习的入门教程中,很多深度学习的模型都是用手写数字mnist数据集进行训练的,在使用过程中通常都有一个batch的分批处理,类似这个:
这个next_batch函数是tensorflow中的函数,我们直接找源码过去copy也不太现实,我们就按照大概的方法写一个,如下:
def get_batches(x, y, n_batches):
batch_size = len(x) // n_batches
ii = 0
while ii < n_batches * batch_size:
# 判断如果这不是最后一个batch,那么这个batch中应该有batch_size个数据
if ii != (n_batches - 1) * batch_size:
X, Y = x[ii: ii + batch_size], y[ii: ii + batch_size]
# 如果是最后一个batch,则剩余不够batch_size的数据都要凑入一个batch中
ii += batch_size
else:
X, Y = x[ii: ], y[ii: ]
# 能走到这一步说明数据已经取完了,为了避免抛出异常可以把ii设置为0,继续while循环
ii = 0
# 生成器语法,返回X, Y
yield X, Y
这样函数就写好了,接下来就是调用的时候:
# 首先在外部定义batch
batch1 = get_batches(data_x_train, data_y_train, n_batches)
# 在循环中不断的get batch,可以一直取,不会stopinteraction
for i in range(training_step):
x_batch, y_batch = batch1.__next__()
# 然后x_batch和y_batch就可以看需要是否要reshape一下,然后就放进去训练了
x_batch = x_batch.values.reshape([-1, n_steps, n_inputs])
# 我在跑LSTM的数据格式可能跟你们不一样
纯手打无粘贴,如有错误请评论或联系我