I recently came across a handy class, ImageDataGenerator, which can do the following:
- Accepting a batch of images used for training.
- Taking this batch and applying a series of random transformations to each image in the batch (including random rotation, resizing, shearing, etc.).
- Replacing the original batch with the new, randomly transformed batch.
- Training the CNN on this randomly transformed batch (i.e., the original data itself is not used for training).
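Put simply, we use it to read in a batch of images, and it automatically augments them according to the options we set (rotation, horizontal flipping, shifting, zooming and so on). This helps reduce overfitting and lets the model learn more varied features.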
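To make the "replace each batch with a randomly transformed copy" idea concrete, here is a minimal pure-Python sketch. It is not how Keras implements it: the helper names `random_horizontal_flip` and `augmented_batches` are made up for illustration, images are tiny nested lists, and the only transformation is a horizontal flip with probability 0.5.

```python
import random


def random_horizontal_flip(image, rng):
    """Return a copy of a 2-D image (list of rows), flipped left to right half the time."""
    if rng.random() < 0.5:
        return [row[::-1] for row in image]
    return [row[:] for row in image]


def augmented_batches(images, batch_size, rng):
    """Yield batches of transformed copies; the originals are never modified."""
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]
        yield [random_horizontal_flip(img, rng) for img in batch]


# Four tiny 2x3 "images" standing in for real training pictures (hypothetical data).
originals = [[[i, i + 1, i + 2], [i + 3, i + 4, i + 5]] for i in range(4)]
snapshot = [[row[:] for row in img] for img in originals]
rng = random.Random(0)

# Each epoch draws fresh random transformations, so the model rarely sees
# exactly the same batch twice.
for epoch in range(2):
    for batch in augmented_batches(originals, batch_size=2, rng=rng):
        print(epoch, batch)
```

The point of the sketch is that augmentation happens on the fly: the stored images stay untouched, and only the copies fed to training change.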
Before using it, we need to initialize ImageDataGenerator:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Updated to do image augmentation
train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')
- rotation_range is a value in degrees (0–180), a range within which to randomly rotate pictures.
- width_shift_range and height_shift_range are ranges (as a fraction of total width or height) within which to randomly translate pictures horizontally or vertically.
- shear_range is for randomly applying shearing transformations.
- zoom_range is for randomly zooming inside pictures.
- horizontal_flip is for randomly flipping half of the images horizontally. This is relevant when there are no assumptions of horizontal asymmetry (e.g. real-world pictures).
- fill_mode is the strategy used for filling in newly created pixels, which can appear after a rotation or a width/height shift.
For more ImageDataGenerator options, see the Keras documentation.
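The fill_mode option is easier to picture on a single row of pixels. The sketch below is a simplified illustration, not Keras code: `shift_row` is a hypothetical helper that shifts a 1-D row by whole pixels and supports only two of the fill modes ('nearest' and 'constant'), whereas Keras works on full images and also offers 'reflect' and 'wrap'.

```python
def shift_row(row, shift, fill_mode="nearest", cval=0):
    """Shift a 1-D row of pixels right by `shift` positions (left if negative).

    Positions that no original pixel lands on are filled according to fill_mode.
    """
    n = len(row)
    out = []
    for i in range(n):
        src = i - shift  # index in the original row that ends up at position i
        if 0 <= src < n:
            out.append(row[src])
        elif fill_mode == "nearest":
            # Repeat the closest edge pixel.
            out.append(row[0] if src < 0 else row[-1])
        else:
            # "constant": fill with a fixed value.
            out.append(cval)
    return out


row = [10, 20, 30, 40, 50]
print(shift_row(row, 2))                        # [10, 10, 10, 20, 30]
print(shift_row(row, -2))                       # [30, 40, 50, 50, 50]
print(shift_row(row, 2, fill_mode="constant"))  # [0, 0, 10, 20, 30]
```

With 'nearest', the gap left by the shift is filled by stretching the edge pixel, which usually looks less artificial than a solid border.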
Next, we can use ImageDataGenerator to read in the images:
# Flow training images in batches of 20 using train_datagen generator
train_generator = train_datagen.flow_from_directory(
    train_dir,  # This is the source directory for training images
    target_size=(150, 150),  # All images will be resized to 150x150
    batch_size=20,
    # Since we use binary_crossentropy loss, we need binary labels
    class_mode='binary')

history = model.fit(
    train_generator,
    steps_per_epoch=100,  # 2000 images = batch_size * steps
    epochs=100,
    verbose=2)
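The steps_per_epoch value above follows from the dataset size and the batch size. As a small sanity check, here is a sketch with a hypothetical helper `steps_for`; rounding up is my own assumption for the case where the image count is not an exact multiple of the batch size.

```python
import math


def steps_for(num_images, batch_size):
    """Number of batches needed to see every image once in an epoch."""
    return math.ceil(num_images / batch_size)


print(steps_for(2000, 20))  # 100, matching steps_per_epoch above
print(steps_for(2010, 20))  # 101, the last batch holds only 10 images
```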
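Reading images with ImageDataGenerator's flow_from_directory method has one rather "magical" property: it labels our images automatically, based on the subdirectory each image sits in! Here the directory structure of train_dir is as follows: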
train_dir/
    cat/
        1.jpg
        2.jpg
        3.jpg
    dog/
        4.jpg
        5.jpg
        6.jpg
ImageDataGenerator will then automatically assign images 1, 2 and 3.jpg to the cat class, and 4, 5 and 6 to the dog class.
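The labeling is less magical than it looks: each subdirectory is treated as one class, and class indices follow the alphanumeric order of the subdirectory names. The sketch below mimics that rule using only the standard library. `infer_labels` is a hypothetical helper, not a Keras function, and the empty files are placeholders, not real images.

```python
import os
import tempfile


def infer_labels(directory):
    """Map each class subdirectory to an index and label every file inside it."""
    class_names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    class_indices = {name: idx for idx, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        class_dir = os.path.join(directory, name)
        for filename in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(name, filename), class_indices[name]))
    return class_indices, samples


# Build a throwaway directory with the cat/dog layout described above.
with tempfile.TemporaryDirectory() as train_dir:
    layout = {"cat": ["1.jpg", "2.jpg", "3.jpg"], "dog": ["4.jpg", "5.jpg", "6.jpg"]}
    for class_name, filenames in layout.items():
        os.makedirs(os.path.join(train_dir, class_name))
        for filename in filenames:
            open(os.path.join(train_dir, class_name, filename), "w").close()
    class_indices, samples = infer_labels(train_dir)

print(class_indices)  # {'cat': 0, 'dog': 1}
for path, label in samples:
    print(path, label)
```

On the real generator, the same mapping can be inspected through train_generator.class_indices.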
The target_size parameter resizes every image to the given size as it is read in; here we resize to 150×150 pixels.
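When training, we then simply pass train_generator. We don't even need to pass the y labels, because the generator already yields each batch of images together with its labels, which is very convenient.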
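The reason no separate y is needed is that the generator hands back (inputs, labels) pairs, one pair per batch. Here is a toy stand-in to show the shape of that protocol. `toy_generator` is hypothetical, works on file names instead of pixel arrays, and does not shuffle, whereas the Keras generator shuffles by default.

```python
import itertools


def toy_generator(samples, batch_size):
    """Loop over (image, label) pairs forever, yielding (x_batch, y_batch)."""
    while True:
        for start in range(0, len(samples), batch_size):
            chunk = samples[start:start + batch_size]
            x_batch = [image for image, _ in chunk]
            y_batch = [label for _, label in chunk]
            yield x_batch, y_batch


# Hypothetical samples: file names with 0 = cat, 1 = dog.
samples = [("1.jpg", 0), ("2.jpg", 0), ("3.jpg", 0),
           ("4.jpg", 1), ("5.jpg", 1), ("6.jpg", 1)]

# The generator never stops on its own, so the consumer decides how many
# batches make up an epoch; this is the role steps_per_epoch plays.
for x_batch, y_batch in itertools.islice(toy_generator(samples, batch_size=4), 3):
    print(x_batch, y_batch)
```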