一般在使用 Keras 进行图像方面的神经网络的训练的时候都会使用图片生成器 ImageDataGeneraor,它不仅仅使用方便,更支持实时数据增强。在使用 CPU 进行数据增强的同时,使用 GPU 训练模型,大大加快了模型训练的速度。
Keras 的《官方文档》 和早期网友翻译的《文档》都对它的使用进行了较为详细的解释,可能由于篇幅的限制,这些文档对于一些较为细节的问题没有过多的介绍,在这里我在前人的基础上进行一些补充。注意本文主要参照上面两个文档的内容而成,我自己单独补充的部分进行了黑体加粗。
1. ImageDataGenerator 类
下面是关于ImageDataGenerator 类的参数说明
keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-06,
rotation_range=0,
width_shift_range=0.0,
height_shift_range=0.0,
brightness_range=None,
shear_range=0.0,
zoom_range=0.0,
channel_shift_range=0.0,
fill_mode='nearest',
cval=0.0,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=None,
validation_split=0.0,
dtype=None)
这个类用以生成一个batch的图像数据,支持实时数据提升。训练时该函数会无限生成数据,直到达到规定的 epoch 次数为止。
1.1 参数
- featurewise_center: 布尔值。将输入数据的均值设置为 0,逐特征进行。这里所谓的逐特征进行,实际上就是逐通道进行,单通道图像减去的均值的是 [[[average]]],三通道图像减去的均值是 [[[R_average, G_average, B_average]]]
- samplewise_center: 布尔值。将每个样本的均值设置为 0。
- featurewise_std_normalization: 布尔值。将输入除以数据标准差,逐特征进行。
- samplewise_std_normalization: 布尔值。将每个输入除以其标准差。
- zca_epsilon: ZCA 白化的 epsilon 值,默认为 1e-6。
- zca_whitening: 布尔值。是否应用 ZCA 白化。
- rotation_range: 整数。随机旋转的度数范围。
- width_shift_range: 浮点数、一维数组或整数。随机水平移动的幅度范围。
float: 如果 <1,则是除以总宽度的值,或者如果 >=1,则为像素值。
1-D 数组: 数组中的随机元素。
int: 来自间隔 (-width_shift_range, +width_shift_range) 之间的整数个像素。
width_shift_range=2 时,可能值是整数 [-1, 0, +1],与 width_shift_range=[-1, 0, +1] 相同;而 width_shift_range=1.0 时,可能值是 [-1.0, +1.0) 之间的浮点数。 - height_shift_range: 浮点数、一维数组或整数。随机水平移动的幅度范围。
float: 如果 <1,则是除以总宽度的值,或者如果 >=1,则为像素值。
1-D array-like: 数组中的随机元素。
int: 来自间隔 (-height_shift_range, +height_shift_range) 之间的整数个像素。
height_shift_range=2 时,可能值是整数 [-1, 0, +1],与 height_shift_range=[-1, 0, +1] 相同;而 height_shift_range=1.0 时,可能值是 [-1.0, +1.0) 之间的浮点数。 - shear_range: 浮点数。剪切强度(以弧度逆时针方向剪切角度)。
补充:什么是剪切变换?
剪切变换是仿射变换的一种原始变换,指的是类似四边形不稳定性的那种性质,方形变四边形,任意一边均可被拉长的过程。 - zoom_range: 浮点数 或 [lower, upper]。随机缩放范围。如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]。
- channel_shift_range: 浮点数。随机通道转换的范围,或者可以理解为随机通道偏移的幅度。
- fill_mode: {“constant”, “nearest”, “reflect” or “wrap”} 之一。常用于填充由于旋转平移造成的图像空白,默认为 ‘nearest’。输入边界以外的点根据给定的模式填充:
‘constant’: kkkkkkkk|abcd|kkkkkkkk (cval=k)
‘nearest’: aaaaaaaa|abcd|dddddddd
‘reflect’: abcddcba|abcd|dcbaabcd
‘wrap’: abcdabcd|abcd|abcdabcd - cval: 浮点数或整数。用于边界之外的点的值,当 fill_mode = “constant” 时。
- horizontal_flip: 布尔值。随机水平翻转。
- vertical_flip: 布尔值。随机垂直翻转。
- rescale: 重缩放因子。默认为 None。如果是 None 或 0,不进行缩放,否则将数据乘以所提供的值(在应用任何其他转换之前&#