Keras 图像预处理 ImageDataGenerator

本文详述Keras中ImageDataGenerator类的使用方法,包括数据增强、批量图像生成、图像标准化与中心化处理,以及如何利用fit()和standardize()函数对数据进行预处理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  • 本文主要介绍Keras中图像分类任务用到的图像预处理部分的内容。
  • 注意:并不是介绍Keras中所有的图像预处理函数。

1. 简介

使用Keras进行图像分类任务时,如果数据集较少(数据获取困难等),为了尽可能的充分利用有限数据的价值,可以进行数据增强处理。

通过一系列随机变换对数据进行提升,这样有利于抑制过拟合,提升模型的泛化能力。

Keras中提供了一个用于数据增强的类(Keras.preprocessing.image.ImageDataGenerator)来实现此功能。这个类可以:

  • 在训练过程中,设置要实施的随机变化
  • 通过.flow.flow_from_directory(directory)方法实例化一个针对图像batch的生成器,这些生成器可以被用做keras相关方法的输入,如fit_generator, evaluate_generatorpredict_generator

什么意思呢?——使用ImageDataGenerator类不仅可以在训练过程中进行图像的随机变化,增加训练数据;还附带赠送了获取数据batch生成器对象的功能,省去了手工再去获取batch数据的部分。

2. ImageDataGenerator类介绍

ImageDataGenerator类路径:keras/preprocessing/image.py

作用:通过实时数据增强生成批量图像数据向量。训练时该函数会无限循环生成数据,直到达到规定的epoch次数为止。

ImageDataGenerator继承于keras_preprocessing/image/image_data_generator.py中的ImageDataGenerator类。

# keras/preprocessing/image.py
class ImageDataGenerator(image.ImageDataGenerator):
    def __init__(self,
                 featurewise_center=False,
                 samplewise_center=False,
                 featurewise_std_normalization=False,
                 samplewise_std_normalization=False,
                 zca_whitening=False,
                 zca_epsilon=1e-6,
                 rotation_range=0,
                 width_shift_range=0.,
                 height_shift_range=0.,
                 brightness_range=None,
                 shear_range=0.,
                 zoom_range=0.,
                 channel_shift_range=0.,
                 fill_mode='nearest',
                 cval=0.,
                 horizontal_flip=False,
                 vertical_flip=False,
                 rescale=None,
                 preprocessing_function=None,
                 data_format=None,
                 validation_split=0.0,
                 dtype=None):

参数

  • featurewise_center:布尔值,使输入数据集去中心化(均值为0),逐特征进行。
  • samplewise_center:布尔值,使输入数据的每个样本均值为0
  • featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行
  • samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差
  • zca_whitening:布尔值,对输入数据施加ZCA白化
  • zca_epsilon: ZCA使用的eposilon,默认1e-6
  • rotation_range:整数,图片随机转动的角度范围
  • width_shift_range:浮点数,一维数组或整数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
    • float:如果<1,则除以总宽度的值,如果>=1,则为宽度像素值
    • 一维数组:数组中的随机元素
    • 整型:来自间隔(-width_shift_range,width_shift_range)之间的整数个像素
    • width_shift_range=2:可能值是整数[-1,0,1],与width_shift_range=[-1,0,1]相同,而当width_shfit_range=1.0时,可能值是半开区间[-1.0,1.0]之间的浮点数(后半句没有理解)。
  • height_shift_range:浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度。具体含义与width_shift_range相同。
  • brightness_range:两个float组成的元组或列表。选择亮度值的范围
  • shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)
  • zoom_range:浮点数或[lower, upper]。随机缩放范围,如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]
  • channel_shift_range:浮点数,随机通道转换的范围。
  • fill_mode{"constant", "nearest", "reflect" or "wrap"} 之一。默认为'nearest'。输入边界以外的点根据给定的模式填充:
    • 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
    • 'nearest': aaaaaaaa|abcd|dddddddd
    • 'reflect': abcddcba|abcd|dcbaabcd
    • 'wrap': abcdabcd|abcd|abcdabcd
  • cval: 浮点数或整数。用于边界之外的点的值,当fill_mode = "constant"时。
  • horizontal_flip: 布尔值,随机水平翻转。
  • vertical_flip: 布尔值,随机垂直翻转。
  • rescale: 重缩放因子。默认为 None。如果是 None 或 0,不进行缩放,否则将数据乘以所提供的值(在应用任何其他转换之后
  • preprocessing_function:该函数应用于每个输入上,在图像被resize和增强之后运行。该函数接收一个参数,一张图像(秩为3的numpy tensor),同样输出一个相同shapeNumpy tensor
  • data_format:图像数据格式,{"channels_first", "channels_last"} 之一。"channels_last" 模式表示图像输入尺寸应该为(samples, height, width, channels)"channels_first" 模式表示输入尺寸应该为(samples, channels, height, width)。默认为 在 Keras 配置文件~/.keras/keras.json中的image_data_format值。如果你从未设置它,那它就是"channels_last"
  • validation_split:浮点型。保留用于验证集的图像比例(严格在0,1之间)
  • dtype:生成数组使用的数据类型。

使用示例

from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

data_generator = datagen.flow_from_directory('./datas/train', target_size=(224,224), batch_size=32)

3. ImageDataGenerator类方法

该类的几个重要方法如下:

  • flow() : 该方法输入数据(Numpy或元组形式)和标签(可选),返回一个迭代器,格式是元组(x,y)(x)(x,y,sample_weight)。该方法还可以指定样本输出路径及前缀,格式,用于保存增强处理后的图像。
  • flow_from_directory(): 获取图像路径,生成批量增强数据。该方法只需指定数据所在的路径,而无需输入numpy形式的数据,也无需输入标签值,会自动返回对应的标签值。返回一个生成(x, y)元组的DirectoryIterator
  • flow_from_dataframe(): 输入数据为Pandas dataframe格式。返回生成(x, y) 元组的DataFrameIterator
注意事项
  • 主要区别是输入数据和输出数据的格式不同。
  • flow_from_directory()flow_from_dataframe()两个函数都将图像resize到指定大小。而flow()
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值