我们在使用model.fit()进行训练的时候, 在这之前你肯定会有训练集的x_img_train,y_label_train两个参数。
fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)
但是当我们使用model.fit_generator()的时候,它的方法是这样的:
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
可以看到它要求传入的参数是一个generator.官网说的很清楚,(不清楚的可以看官网)这里的generator是一个生成器,主要是训练自己的数据,并且数据非常多的时候可以不用把数据全部加载进内存,而是用生成器自己一点点读取。大大提高的运行效率。
下面是这个生成器的生成方法:
#这是训练集的生成器
train_datagen = ImageDataGenerator(
rescale=1. / 255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
## 训练图片生成器
train_generator = train_datagen.flow_from_directory(
train_data_dir,#训练样本地址
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical') #多分类
test_datagen = ImageDataGenerator(rescale=1. / 255)
##验证集的生成器
validation_generator = test_datagen.flow_from_directory(
validation_data_dir,#验证样本地址
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical',
shuffle=False) #多分类
好了,有了这个train_generator生成器我们就可以入入fit_generator(...)里面进行训练了。
对了,这里说明下train_data_dir / validation_data_dir 是我本机的训练集与验证集的地址。
目录结构形似:
'''
data/train/
1/
001.jpg
002.jpg
...
2/
001.jpg
002.jpg
...
data/validation/
1/
001.jpg
002.jpg
...
2/
001.jpg
002.jpg
...
'''