# Generation: define the batch generator.
def _mnist_to_bgr224(image):
    """Convert one image to a 224x224, 3-channel BGR array (the original preprocessing)."""
    # Reshape to 28x28; presumably the input is a flattened 784-element MNIST row.
    square = np.reshape(image, [28, 28])
    # Upscale to 224x224, then replicate the single gray channel into three BGR channels.
    return cv2.cvtColor(cv2.resize(square, (224, 224)), cv2.COLOR_GRAY2BGR)


def gen_data(xs, ys, batch_size=100, transform=None):
    """Yield (images, labels) batches from xs/ys forever, cycling over the data.

    Args:
        xs: Iterable of images. With the default transform, each image must be
            reshapeable to 28x28.
        ys: Iterable of labels, paired element-wise with xs via zip (extra
            elements of the longer one are ignored).
        batch_size: Number of samples per yielded batch; must be >= 1.
        transform: Optional callable applied to each image before batching.
            Defaults to reshaping to 28x28, resizing to 224x224 with cv2 and
            converting grayscale to 3-channel BGR.

    Yields:
        A tuple (np.ndarray of images, np.ndarray of labels), each containing
        exactly batch_size entries. A partial batch left at the end of one pass
        is kept and completed with samples from the start of the next pass.

    Raises:
        ValueError: on the first ``next()`` if batch_size < 1, or when a full
            pass over xs/ys produces no samples (which would otherwise loop
            forever without ever yielding).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {!r}".format(batch_size))
    if transform is None:
        transform = _mnist_to_bgr224

    image_buf, label_buf = [], []
    while True:
        seen_any = False
        for image, label in zip(xs, ys):
            seen_any = True
            image_buf.append(transform(image))
            label_buf.append(label)
            # Flush only after storing the current sample, so no sample is
            # dropped at a batch boundary.
            if len(image_buf) == batch_size:
                yield np.array(image_buf), np.array(label_buf)
                image_buf, label_buf = [], []
        if not seen_any:
            # An empty (or already-exhausted one-shot) input can never fill a batch.
            raise ValueError("gen_data: xs/ys produced no samples; cannot build a batch")
# A single endless batch generator over mnist.train images/labels. batch_size is
# written out explicitly (it equals gen_data's default of 100) so the batch
# length is visible at the call site.
gen = gen_data(
    xs=mnist.train.images,
    ys=mnist.train.labels,
    batch_size=100,
)
# Invocation: draw batches from the generator in the training loop.
# Pull one batch from `gen` per iteration, then take a [start:end] slice of it.
# NOTE(review): num_batches, start, batch_size and max_data_len are not defined
# in this snippet; the loop body may also continue beyond what is shown.
for i in range(num_batches):
    # xs/ys are one generator batch; `gen` above was created with gen_data's
    # default batch_size of 100 unless configured otherwise.
    xs,ys=next(gen)
    # NOTE(review): `start` is never set or advanced here; confirm where it comes from.
    end=start+batch_size
    # Clamp the slice end to max_data_len (presumably the dataset length - verify).
    end=min(end,max_data_len)
    # NOTE(review): xs/ys already hold exactly one batch, so slicing them again with
    # dataset-level offsets looks like a leftover from index-based batching. If
    # `start` grows past len(xs), xx/yy will be empty - TODO confirm intent
    # (likely `xx, yy = xs, ys` is what was meant).
    xx=xs[start:end]
    yy=ys[start:end]