代码
import numpy as np
def batch_gen(data, batch_size=10):  # batch data generator 1 (sequential)
    """Yield consecutive, non-shuffled batches of ``data`` forever.

    Batches are taken in order: ``data[0:batch_size]``,
    ``data[batch_size:2*batch_size]``, ... When the next full batch would run
    past the end of ``data``, iteration wraps back to index 0. A trailing
    partial batch is therefore never produced. The exception is when
    ``batch_size`` exceeds ``len(data)``: then every batch is all of ``data``.

    Args:
        data: A sliceable sequence supporting ``len()`` (e.g. a NumPy array
            or a list).
        batch_size: Number of items per batch. Defaults to 10, the value the
            previous implementation hard-coded.

    Yields:
        Slices of ``data``.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    size = len(data)  # previously hard-coded as 100
    idx = 0
    while True:
        # Wrap around rather than emitting a short batch at the end.
        if idx + batch_size > size:
            idx = 0
        start = idx
        idx += batch_size
        yield data[start:start + batch_size]
def batch_generator(data, batch_size):  # batch data generator 2 (shuffled)
    """Yield shuffled, fixed-size batches of ``data`` forever.

    The rows of ``data`` are shuffled once up front. Consecutive slices of
    ``batch_size`` rows are then yielded from the shuffled copy. When fewer
    than ``batch_size`` rows remain, those leftover rows are dropped for that
    pass, the copy is reshuffled, and iteration restarts from the beginning.
    ``data`` itself is never modified.

    Args:
        data: A NumPy array; batches are taken along axis 0.
        batch_size: Number of rows per batch. Must satisfy
            ``1 <= batch_size <= data.shape[0]``.

    Yields:
        Arrays of exactly ``batch_size`` rows.

    Raises:
        ValueError: If ``batch_size`` is out of range. Previously this case
            either spun forever without yielding (``batch_size`` larger than
            the data) or yielded empty batches (``batch_size <= 0``). Because
            this is a generator, the error surfaces on the first ``next()``.
    """
    size = data.shape[0]
    if not 1 <= batch_size <= size:
        raise ValueError(
            f"batch_size must be in [1, {size}], got {batch_size}"
        )
    # Fancy indexing returns a new array, so ``data`` is left untouched.
    data_copy = data[np.random.permutation(size)]
    idx = 0
    while True:
        if idx + batch_size > size:
            # End of a pass: reshuffle and start over, dropping the remainder.
            data_copy = data_copy[np.random.permutation(size)]
            idx = 0
        yield data_copy[idx:idx + batch_size]
        idx += batch_size
if __name__ == '__main__':
    data = np.arange(100)

    # Result 1: sequential batches of 10; wraps back to the start after
    # ten batches, so the 20 printed batches repeat the sequence twice.
    gen = batch_gen(data)
    for _ in range(20):
        batch = next(gen)  # pull the next batch from the generator
        print(batch)

    # Result 2: shuffled batches of 10; reshuffled after each full pass.
    gen = batch_generator(data, batch_size=10)
    for _ in range(20):
        batch = next(gen)
        print(batch)
结果 1:
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]
[50 51 52 53 54 55 56 57 58 59]
[60 61 62 63 64 65 66 67 68 69]
[70 71 72 73 74 75 76 77 78 79]
[80 81 82 83 84 85 86 87 88 89]
[90 91 92 93 94 95 96 97 98 99]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]
[50 51 52 53 54 55 56 57 58 59]
[60 61 62 63 64 65 66 67 68 69]
[70 71 72 73 74 75 76 77 78 79]
[80 81 82 83 84 85 86 87 88 89]
[90 91 92 93 94 95 96 97 98 99]
[Finished in 0.2s]
结果2（注：下面的输出每批 30 个元素，是以 batch_size=30 运行得到的；按上面代码中的 batch_size=10 运行时，每批为 10 个元素）:
[11 33 96 21 61 60 80 58 75 10 40 15 2 27 84 17 29 94 72 39 7 47 78 31
83 66 97 88 43 3]
[48 19 52 79 86 92 54 44 32 9 46 18 62 4 55 81 87 1 25 59 45 50 6 57
12 73 67 37 38 82]
[56 36 49 34 30 68 99 69 77 98 85 26 5 51 24 90 14 42 76 28 20 35 91 65
22 13 0 89 71 53]
[83 1 29 0 17 85 54 22 41 78 92 68 37 80 71 57 72 67 90 62 12 66 52 95
43 65 86 21 39 38]
[69 28 18 49 34 56 50 91 27 3 76 35 60 82 19 11 45 99 44 96 81 46 40 88
48 24 94 97 16 93]
[77 55 4 20 32 70 6 74 79 87 36 15 58 61 64 13 9 26 23 84 25 5 73 10
53 31 98 2 33 59]
[53 8 99 79 86 6 36 59 18 71 69 60 77 58 61 48 98 15 82 89 1 16 37 28
74 78 90 2 14 62]
[55 80 87 42 85 88 70 29 39 5 52 24 68 32 83 57 45 73 93 38 7 4 75 44
34 72 41 10 13 12]
[51 96 25 81 54 46 11 91 76 20 84 33 66 23 92 63 35 43 65 17 22 9 49 95
47 26 64 27 94 0]
[46 44 96 5 62 13 16 30 82 84 91 31 21 32 22 69 43 37 56 86 73 35 38 67
52 18 90 41 0 68]
[ 8 27 51 12 7 79 98 53 85 20 80 97 92 50 88 59 10 1 57 58 54 65 19 2
95 34 89 70 33 83]
[29 17 61 36 99 39 23 75 78 45 72 47 87 40 77 4 24 6 25 81 66 3 93 26
71 11 76 15 94 60]
[44 40 11 95 61 8 0 70 18 37 62 53 78 49 76 73 87 27 6 90 35 92 47 30
34 84 93 72 75 26]
[17 10 42 21 57 82 22 71 65 89 66 4 99 77 74 54 41 83 2 38 3 88 13 28
97 9 32 96 52 24]
[20 94 51 25 50 39 12 29 16 1 5 56 19 33 91 67 86 7 23 69 45 81 80 43
15 68 55 63 64 36]
[ 2 77 27 32 72 57 5 0 93 33 25 8 86 81 36 59 92 63 97 49 30 20 52 16
42 18 26 6 51 83]
[17 53 3 94 50 95 12 78 11 96 55 37 70 15 39 89 23 61 43 45 31 9 24 80
99 54 76 10 69 98]
[ 7 65 14 85 79 64 38 1 35 87 56 68 46 62 22 19 75 60 40 84 29 90 13 91
28 4 21 44 71 67]
[14 0 12 60 7 38 42 80 70 28 65 11 49 78 32 77 90 5 10 91 71 87 61 53
76 72 29 40 2 68]
[56 64 44 48 27 20 34 94 92 26 82 41 13 19 85 8 47 45 52 98 79 89 96 66
63 4 58 3 55 30]