fit.generator处理数据增强输入必须是生成器类型
昨天提到的fit_generator处理非常大的数据集的时候比fit好用,所以昨天就直接把fit用的数据集直接放在generator下面跑,结果一直报错,后来问了师兄后,发现是输入数据类型不对,fit_generator函数传入的类型必须是生成器类型,也就是经过生成器产生的。
例:定义一个生成器
batch_size = 128
def generator():
while 1:
row = np.random.randint(0,len(x_train),size=batch_size)
x = np.zeros((batch_size,x_train.shape[-1]))
y = np.zeros((batch_size,))
x = x_train[row]
y = y_train[row]
yield x,y