官方文本解释:https://tensorflow.google.cn/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator
keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-6,
rotation_range=0.,
width_shift_range=0.,
height_shift_range=0.,
shear_range=0.,
zoom_range=0.,
channel_shift_range=0.,
fill_mode='nearest',
cval=0.,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=K.image_data_format())
用以生成一个batch的图像数据,支持实时数据提升。训练时该函数会无限生成数据,直到达到规定的epoch次数为止。
featurewise_center:布尔值,使输入数据集去中心化(均值为0), 按feature执行
samplewise_center:布尔值,使输入数据的每个样本均值为0
featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行
samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差
zca_whitening:布尔值,对输入数据施加ZCA白化
zca_epsilon: ZCA使用的eposilon,默认1e-6
rotation_range:整数,数据提升时图片随机转动的角度
width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
height_shift_range:浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度
shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)
zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]
channel_shift_range:浮点数,随机通道偏移的幅度
fill_mode:;‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理
cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值
horizontal_flip:布尔值,进行随机水平翻转
vertical_flip:布尔值,进行随机竖直翻转
rescale: 重放缩因子,默认为None. 如果为None或0则不进行放缩,否则会将该数值乘到数据上(在应用其他变换之前)
preprocessing_function: 将被应用于每个输入的函数。该函数将在图片缩放和数据提升之后运行。该函数接受一个参数,为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array
data_format:字符串,“channel_first”或“channel_last”之一,代表图像的通道维的位置。该参数是Keras 1.x中的image_dim_ordering,“channel_last”对应原本的“tf”,“channel_first”对应原本的“th”。以128x128的RGB图像为例,“channel_first”应将数据组织为(3,128,128),而“channel_last”应将数据组织为(128,128,3)。该参数的默认值是~/.keras/keras.json中设置的值,若从未设置过,则为“channel_last”
实验代码:
# 综合增强示例
from tensorflow import keras
from numpy import expand_dims
from matplotlib import pyplot
# 读入图片
img = tf.keras.preprocessing.image.load_img(img_path)
# 转换为 numpy 数组
data = tf.keras.preprocessing.image.img_to_array(img)
# 扩展维度
samples = expand_dims(data, 0)
print(samples.shape)
# 生成数据增强迭工厂
datagen = tf.keras.preprocessing.image.ImageDataGenerator(shear_range=20,
horizontal_flip=True,
channel_shift_range=100,
zoom_range=[0.5,2],
fill_mode = 'nearest')
# 准备迭代器
it = datagen.flow(samples, batch_size=1)
# 生成数据并画图
for i in range(9):
# 定义子图
pyplot.subplot(330 + 1 + i)
# 生成一个批次图片
batch = it.next()
# 转换为无符号整型方便显示
image = batch[0].astype('uint32')
# 画图
pyplot.imshow(image)
# 展示图片
pyplot.show()
参考连接:
https://blog.csdn.net/weixin_43917589/article/details/109771019
https://blog.csdn.net/xjq_ncu/article/details/80260735?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.control&dist_request_id=d570f3d4-b12b-4d2f-b798-23a71a39d2b3&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.control
https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/78271479?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-3.control&dist_request_id=d570f3d4-b12b-4d2f-b798-23a71a39d2b3&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-3.control