# NumPy is required by everything below; the original snippet used `np`
# without importing it.
import numpy as np

# Toy dataset for the simple generator demo: the integers 0..99.
a = np.arange(100)
def batch_gen(data, batch_size=10):
    """Define a batch data generator: yield consecutive slices of ``data`` forever.

    Batches are taken in order. When the next full batch would run past the
    end of ``data``, the position wraps back to 0, so trailing items that do
    not fill a whole batch are skipped on that pass.

    Args:
        data: a sliceable sequence supporting ``len`` (e.g. a NumPy array).
        batch_size: number of items per batch. Defaults to 10, the value the
            original version hard-coded.

    Yields:
        ``data[start:start + batch_size]`` for successive ``start`` offsets.

    Raises:
        ValueError: on the first ``next()`` call if ``batch_size`` is not
            positive (a zero or negative size would never advance).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    idx = 0
    while True:
        # Wrap around based on the real data length instead of a hard-coded
        # 100, so inputs of any size cycle correctly.
        if idx + batch_size > len(data):
            idx = 0
        start = idx
        idx += batch_size
        yield data[start:start + batch_size]
gen = batch_gen(a)
for i in range(20):
    # Each next() call resumes the generator and hands back one batch;
    # 20 calls walk through the 100-element `a` twice in batches of 10.
    b = next(gen)
    print(b)
def shuffle_aligned_list(data):
    """Shuffle arrays in a list by shuffling each array identically.

    One random permutation of the first-axis indices is drawn with
    ``np.random.permutation`` and applied to every array. Rows that were
    aligned across the arrays therefore stay aligned after shuffling.

    Args:
        data: list of NumPy arrays sharing the same length along axis 0.

    Returns:
        A new list of shuffled arrays. The inputs are not modified, because
        integer-array (advanced) indexing returns copies. An empty input list
        gives an empty list.

    Raises:
        ValueError: if the arrays do not all have the same first-axis length.
            Without this check, shorter arrays raised a confusing IndexError
            and longer ones were silently truncated.
    """
    if len(data) == 0:
        return []
    num = data[0].shape[0]
    lengths = [d.shape[0] for d in data]
    if any(n != num for n in lengths):
        raise ValueError(
            "all arrays must have the same length along axis 0, got %r" % (lengths,)
        )
    p = np.random.permutation(num)
    return [d[p] for d in data]
def batch_generator(data, batch_size, shuffle=True):
    """Yield aligned mini-batches from a list of arrays, indefinitely.

    Every array in ``data`` is cut with the same ``[offset:offset + batch_size]``
    window, so the i-th element of each yielded list refers to the same rows.
    When the next full batch would run past the end of ``data[0]``, the window
    jumps back to the start and any trailing partial batch is skipped. If
    ``shuffle`` is true, the arrays are reshuffled together at that point and
    also once before the very first batch. If ``batch_size`` exceeds the
    number of rows, every batch is simply the whole (shorter) array.

    Args:
        data: list of array-like objects sharing the same first dimension.
            They must support permutation-array indexing when shuffling,
            e.g. NumPy arrays.
        batch_size: number of rows per batch.
        shuffle: whether to shuffle before the first pass and on every wrap.

    Yields:
        A list containing one ``batch_size``-row slice of each input.
    """
    if shuffle:
        data = shuffle_aligned_list(data)
    offset = 0
    while True:
        wraps = offset + batch_size > len(data[0])
        if wraps:
            offset = 0
            if shuffle:
                data = shuffle_aligned_list(data)
        stop = offset + batch_size
        yield [arr[offset:stop] for arr in data]
        offset = stop
实例化批量数据生成器
import numpy as np

# Make data: 10 samples with 5 features each, plus one label per sample.
# (In the original paste the assignment to `x` was swallowed into the
# comment, leaving `x` undefined.)
x = np.arange(1, 51).reshape((10, 5))
y = np.arange(1, 11).reshape((-1, 1))

# Define epochs and batch_size. `epochs` was referenced but never defined;
# 10 matches the "Epoch 1" .. "Epoch 10" output shown for this demo.
epochs = 10
batch_size = 5
batches_per_epoch = len(x) // batch_size
epochs_num = batches_per_epoch * epochs  # total number of batches to draw

# Define the generator. NOTE(review): this rebinds the module-level name
# `batch_gen`, which shadows the simple `batch_gen` function defined earlier.
batch_gen = batch_generator([x, y], batch_size)

# Use next() inside the loop to pull batches from the generator.
for i in range(epochs_num):
    batch_x, batch_y = next(batch_gen)
    if i % batches_per_epoch == 0:
        # 1-based epoch label, consistent with the printed output.
        print('Epoch %d' % (i // batches_per_epoch + 1))
    print('The %d Batch_y:\n ' % (i + 1), batch_y.reshape((1, -1)), '\n')
Epoch 1
The 1 Batch_y:
[[ 5 2 7 10 4]]
The 2 Batch_y:
[[1 3 6 9 8]]
Epoch 2
The 3 Batch_y:
[[2 3 6 8 1]]
The 4 Batch_y:
[[ 5 9 10 4 7]]
Epoch 3
The 5 Batch_y:
[[ 2 10 1 3 8]]
The 6 Batch_y:
[[7 9 6 5 4]]
Epoch 4
The 7 Batch_y:
[[2 4 5 3 6]]
The 8 Batch_y:
[[ 1 7 9 8 10]]
Epoch 5
The 9 Batch_y:
[[10 4 7 5 2]]
The 10 Batch_y:
[[1 8 9 6 3]]
Epoch 6
The 11 Batch_y:
[[6 3 8 9 7]]
The 12 Batch_y:
[[ 1 10 2 4 5]]
Epoch 7
The 13 Batch_y:
[[1 3 4 5 8]]
The 14 Batch_y:
[[ 9 6 7 10 2]]
Epoch 8
The 15 Batch_y:
[[ 7 10 5 2 9]]
The 16 Batch_y:
[[6 4 3 1 8]]
Epoch 9
The 17 Batch_y:
[[ 3 10 2 4 8]]
The 18 Batch_y:
[[5 6 1 7 9]]
Epoch 10
The 19 Batch_y:
[[ 1 4 7 10 3]]
The 20 Batch_y:
[[8 9 6 5 2]]