1 tf.keras.preprocessing.image.ImageDataGenerator()
生成批量张量图像数据与实时数据增强
白化是一种重要的预处理过程,其目的就是降低输入数据的冗余性,使得经过白化处理的输入数据具有如下性质:(i)特征之间相关性较低;(ii)所有特征具有相同的方差。
tf.keras.preprocessing.image.ImageDataGenerator(
featurewise_center=False,输入数据集的均值设置为0
samplewise_center=False,每个样本的均值设置为0
featurewise_std_normalization=False,根据数据集的标准差std对输入进行区分。
samplewise_std_normalization=False,将每个输入除以它的std。
zca_whitening=False,应用ZCA白化
zca_epsilon=1e-06,白化参数
rotation_range=0,随机旋转角度
width_shift_range=0.0,宽度方向平移
height_shift_range=0.0,高度方向平移
brightness_range=None,元组或两个浮动的列表。用于从中选择亮度值的范围。
shear_range=0.0,剪切强度(逆时针方向剪切角,单位为度),即为错切变换
zoom_range=0.0,随机缩放范围
channel_shift_range=0.0, 随机给某个通道的所有像素加上设置值
fill_mode='nearest', 填充方式
cval=0.0,当fill_mode='constant'时输入边界之外的点的填充值,默认0
horizontal_flip=False,随机水平翻转输入
vertical_flip=False,随机垂直翻转输入
rescale=None,图像像素归一化因子,一般为1/255,在完成所有转换后给数据×1/255
preprocessing_function=None,用户自定义函数进行图像增强,输入参数为图像,输出为numpy张量
data_format=None,数据格式默认channels_last,即(n,h,w,c)
validation_split=0.0,验证集的分割比例
dtype=None 生成数组的类型
)
2 .flow()
获取数据和标签数组,生成批量增强数据。
flow(
x, 输入数据
y=None, 对应标签
batch_size=32,
shuffle=True,默认打乱数据
sample_weight=None,
seed=None,随机数种子
save_to_dir=None,保存增强数据的地址
save_prefix='',用于保存的图片的文件名的前缀(
save_format='png',保存格式
ignore_class_split=False,忽略类数量的差异
subset=None如果在ImageDataGenerator中设置了validation_split,则是数据的子集("training"或"validation"))
3 示例
这里只展示如何应用ImageDataGenerator()和flow()
train_data=ImageDataGenerator(
rescale=1/255,
rotation_range=40,#旋转40度
width_shift_range=0.2, #宽度方向平移20%
height_shift_range=0.2,
shear_range=0.1,
zoom_range=0.2,
horizontal_flip=True, #水平翻转
fill_mode='nearest' #填充方式最近邻
)
test_data=ImageDataGenerator(rescale=1/255)model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,optimizer=
tf.optimizers.RMSprop(learning_rate=0.001),metrics=['acc'])
#模型开始训练
start_time=datetime.datetime.now()
history=model.fit(train_data.flow(training_images,training_labels,batch_size=32),
steps_per_epoch=len(training_images)/32,#全部训练数据用一次所需要的epoch
epochs=15,#一共训练15轮
validation_data=test_data.flow(testing_images,testing_labels,batch_size=32),
validation_steps=len(testing_images)/32 #全部测试数据用一次所需要的epoch
)