自定义 DataGenerator
- 生成器：结合循环与 yield，按批次（batch）逐批产生数据
import numpy as np
class DataGenerator(object):
    """Produce shuffled mini-batches from a pair of input/target arrays.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch; the final batch of a pass may be
        smaller when the sample count is not a multiple of it.
    """

    def __init__(self, batch_size):
        self.batch_size = batch_size

    def generate(self, xs, ys):
        """Yield ``(x_batch, y_batch)`` tuples covering one shuffled pass.

        Only the first element of ``xs`` and of ``ys`` is used.
        NOTE(review): presumably ``xs``/``ys`` are lists of model inputs /
        targets (Keras-style multi-input lists) -- confirm with callers.

        Parameters
        ----------
        xs, ys : sequence
            ``xs[0]`` and ``ys[0]`` must have the same length and support
            indexing by an integer array (e.g. NumPy arrays).

        Yields
        ------
        tuple
            ``(x[batch_idx], y[batch_idx])`` for each batch. Each call visits
            every sample exactly once, in a freshly shuffled order, and then
            stops.

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive, or ``xs[0]`` and ``ys[0]``
            differ in length. Because this is a generator, the error surfaces
            on the first ``next()`` call.
        """
        x = xs[0]
        y = ys[0]
        batch_size = self.batch_size
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        n_samples = len(x)
        if len(y) != n_samples:
            # Mismatched lengths would silently misalign inputs and targets.
            raise ValueError(
                "xs[0] and ys[0] must have the same length ({} != {})".format(
                    n_samples, len(y)
                )
            )

        # Random permutation of sample positions, drawn from NumPy's global RNG.
        index = np.arange(n_samples)
        np.random.shuffle(index)

        # Stepping by batch_size guarantees termination after
        # ceil(n_samples / batch_size) batches; the last slice simply holds the
        # remainder (no float ceil / integer-division pitfalls).
        for start in range(0, n_samples, batch_size):
            batch_idx = index[start:start + batch_size]
            yield x[batch_idx], y[batch_idx]