#类:train, test都需要有rescale
train_datagen = ImageDataGenerator(
rotation_range=15,
rescale=1./255,
shear_range=0.1,
zoom_range=0.2,
horizontal_flip=True,
width_shift_range=0.1,
height_shift_range=0.1
)
train_generator = train_datagen.flow_from_dataframe(
train_df,
"../working/train/",
x_col='filename',
y_col='category',
target_size=IMAGE_SIZE,
class_mode='categorical',
batch_size=batch_size
)
#相当于一个迭代器,可以不断生成新的不同的图片
#Pandas dataframe containing the filepaths relative to directory
#string, column in dataframe that contains the filenames (or absolute paths if directory is None).
#string or list, column/s in dataframe that has the target data.
example_df = train_df.sample(n=1).reset_index(drop=True)
example_generator = train_datagen.flow_from_dataframe(
example_df,
"../working/train/",
x_col='filename',
y_col='category',
target_size=IMAGE_SIZE,
class_mode='categorical'
)
plt.figure(figsize=(12, 12))
#生成15张图片(生成的数量根据自己的需要决定)
for i in range(0, 15):
#subplot: 在i+1的位置上画图
plt.subplot(5, 3, i+1)
#generator: 无限的以一个batch为单位生成数据,所以需要取第一个
#图片,并且在生成第一张图片后break
for X_batch, Y_batch in example_generator:
image = X_batch[0]
plt.imshow(image)
break
plt.tight_layout()
plt.show()
total_train = train_df.shape[0]
total_validate = validate_df.shape[0]
epochs=3 if FAST_RUN else 50
#由于generator会无限生成数据,所以需要根据steps_per_epoch的值
#决定这个epoch何时结束,何时开始新的epoch
history = model.fit_generator(
train_generator,
epochs=epochs,
validation_data=validation_generator,
validation_steps=total_validate//batch_size,
steps_per_epoch=total_train//batch_size,
callbacks=callbacks
)
predict = model.predict_generator(test_generator, steps=np.ceil(nb_samples/batch_size))
tensorflow ImageDataGenerator的使用
最新推荐文章于 2022-09-27 21:34:41 发布