- fit和fit_generator的区别
首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练。
下面是fit传参的例子:
history = model.fit(x_train, y_train, epochs=10,batch_size=32,
validation_split=0.2)
这里需要给出epochs和batch_size,epoch是这个数据集要被轮多少次,batch_size是指这个数据集被分成多少个batch进行处理。
最后可以给出交叉验证集的大小,这里的0.2是指在训练集上占比20%。
fit_generator函数必须传入一个生成器,我们的训练数据也是通过生成器产生的,下面给出一个简单的生成器函数:
batch_size = 128
def generator():
while 1:
row = np.random.randint(0,len(x_train),size=batch_size)
x = np.zeros((batch_size,x_train.shape[-1]))
y = np.zeros((batch_size,))
x = x_train[row]
y = y_train[row]
yield x,y