07_tf.keras trick之使用图像数据增强
1. tf.keras.preprocessing.image模块
在tensorflow.keras.preprocessing.image
模块下有一系列的针对图像数据的预处理增强方法,同级别下还有序列sequence
以及文本text
的方法,这里主要记录下应用于图像的数据增强。
具体的方法列表可以在下面这张图看到,该图截自:https://tensorflow.google.cn/api_docs/python/tf。
2.tf.keras.preprocessing.image.ImageDataGenerator类
本部分可以参考 https://keras-zh.readthedocs.io/preprocessing/image/。
tensorflow.keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-06,
rotation_range=0,
width_shift_range=0.0,
height_shift_range=0.0,
brightness_range=None,
shear_range=0.0,
zoom_range=0.0,
channel_shift_range=0.0,
fill_mode='nearest',
cval=0.0,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=None,
validation_split=0.0,
dtype=None)
该类通过实时数据增强生成张量图像数据批次,其中的参数用于定义图像增强的方法,主要参数如下:
rotation_range
是一个0~180的度数,用来指定随机选择图片的角度。width_shift
和height_shift
用来指定水平和竖直方向随机移动的程度,这是两个0~1之间的比rescale
值将在执行其他处理前乘到整个图像上,我们的图像在RGB通道都是0 ~ 255的整数,这样的操作可能使图像的值过高或过低,所以我们将这个值定为0~1之间的数。shear_range
是用来进行错切变换的程度,参考错切变换zoom_range
用来进行随机的放大horizontal_flip
随机的对图片进行水平翻转,这个参数适用于水平翻转不影响图片语义的时候fill_mode
用来指定当需要进行像素填充,如旋转,水平和竖直位移时,如何填充新出现的像素
2.1 .flow方法
.flow
方法可以采集数据和标签数组,一般是已经读入内存的图像数据,利用其生成批量增强数据。
方法原型:
flow(x, y=None, batch_size=32, shuffle=True,
sample_weight=None, seed=None, save_to_dir=None,
save_prefix='', save_format='png', subset=None)
主要参数解释:
x
: 输入数据。秩为4的Numpy 矩阵或元组。如果是元组,第一个元素应该包含图像,第二个元素是另一个 Numpy 数组或一列 Numpy 数组,它们不经过任何修改就传递给输出。可用于将模型杂项数据与图像一起输入。对于灰度数据,图像数组的通道轴的值应该为 1,而对于 RGB 数据,其值应该为3。y
: 标签。batch_size
: 整数,默认为 32。shuffle
: 布尔值,默认为 True。sample_weight
: 样本权重。seed
: 整数,默认为None。save_to_dir
: 默认为 None,如果为字符串则表示指定要保存的正在生成的增强图片的目录,用于可视化正在执行的操作。save_prefix
: 字符串,默认为空字符串,。保存图片的文件名前缀,仅当save_to_dir设置时可用。save_format
: “png"或 “jpeg” ,仅当 save_to_dir 设置时可用,默认为"png”。
该方法可以返回一个生成元组 (x, y) 的迭代器,其中x是图像数据的Numpy数组(单张图像输入时),或 Numpy 数组列表(多个输入时),y是对应的标签的Numpy数组。如果’sample_weight’不是None,生成的元组形式为(x, y, sample_weight)。如果 y 是None, 只有Numpy数组x被返回。
例如下面的代码,这段代码利用.flow
方法由一张图片生成增强处理后的20张图片:
import numpy as np
from tensorflow.keras.preprocessing.image import *
# 定义ImageDataGenerator类
datagen = ImageDataGenerator(
rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,rescale = 1/255,
shear_range = 20,zoom_range = 0.2,horizontal_flip = True,fill_mode = 'nearest')
# 读入一张图片至计算机内存
img = load_img('sequential_model/data/cat_vs_dog/train/cat/cat.1.jpg')
# 转换成数组格式
x = img_to_array(img)
x = np.expand_dims(x,0)
# 生成20张图片
i = 0
for batch in datagen.flow(x, batch_size=1, save_to_dir='sequential_model/data/generate_images', save_prefix='new_cat', save_format='jpeg'):
i += 1
if i==20:
break
print('image generate finished!')
最终结果:
本部分完整代码:
06_image_data_enhancement.py
2.2 .flow_from_directory方法
.flow_from_directory
方法可以从指定路径中读取图片数据来生成指定尺寸的增强后的图像数据。
方法原型:
flow_from_directory(directory, target_size=(256, 256), color_mode='rgb',
classes=None, class_mode='categorical', batch_size=32,
shuffle=True, seed=None, save_to_dir=None, save_prefix='',
save_format='png', follow_links=False, subset=None,
interpolation='nearest')
主要参数解释:
directory
: 目标目录路径,一般每个类单独放在一个子目录中,任何在子目录树下的 PNG, JPG, BMP, PPM ,TIF 图像都将被包含在生成器中。target_size
: 整数元组 (height, width),默认(256, 256),所有的图像将被调整到的尺寸。color_mode
: “grayscale”, “rbg” 之一。默认:“rgb”,图像被转换成1或3个颜色通道。classes
: 可选的类的子目录列表(例如 [‘dogs’, ‘cats’])。默认:None。如果未提供,类的列表将自动从 directory 下的 子目录名称/结构 中推断出来,其中每个子目录都将被作为不同的类(类名将按字典序映射到标签的索引)。包含从类名到类索引的映射的字典可以通过 class_indices 属性获得。class_mode
: “categorical”, “binary”, “sparse”, “input” 或 None 之一。默认:“categorical”。决定返回的标签数组的类型:-
- “categorical” 将是 2D one-hot 编码标签,
-
- “binary” 将是 1D 二进制标签,“sparse” 将是 1D 整数标签,
-
- “input” 将是与输入图像相同的图像(主要用于自动编码器)。
-
- 如果为 None,不返回标签(生成器将只产生批量的图像数据,对于 model.predict_generator(), model.evaluate_generator() 等很有用)。请注意,如果 class_mode 为 None,那么数据仍然需要驻留在 directory 的子目录中才能正常工作。
shuffle
: 是否混洗数据(默认 True)。seed
: 可选随机种子,用于混洗和转换。follow_links
: 是否跟踪类子目录中的符号链接(默认为 False)。interpolation
: 在目标大小与加载图像的大小不同时,用于重新采样图像的插值方法。 支持的方法有 “nearest”, “bilinear”, and “bicubic”。 如果安装了 1.1.3 以上版本的 PIL 的话,同样支持 “lanczos”。 如果安装了 3.4.0 以上版本的 PIL 的话,同样支持 “box” 和 “hamming”。 默认情况下,使用 “nearest”。