keras.src.legacy.preprocessing.image.ImageDataGenerator

keras.src.legacy.preprocessing.image.ImageDataGenerator 是 Keras 中用于图像数据增强和批量生成训练数据的工具类,属于 Keras 旧版本的图像处理模块。它支持实时数据增强(如旋转、翻转、缩放等),适用于图像分类、目标检测等计算机视觉任务。

主要功能

  1. 数据增强:通过随机变换生成多样化的训练样本,提高模型的泛化能力。
  2. 批量生成:从目录或数组中批量加载图像数据,减少内存占用。
  3. 归一化:对图像像素值进行标准化处理(如缩放至 [0,1] 或归一化到均值为 0)。

核心参数

datagen = ImageDataGenerator(
    rotation_range=20,        # 随机旋转角度范围
    width_shift_range=0.2,    # 水平平移范围
    height_shift_range=0.2,   # 垂直平移范围
    shear_range=0.2,          # 剪切强度
    zoom_range=0.2,           # 缩放范围
    horizontal_flip=True,     # 随机水平翻转
    vertical_flip=False,      # 随机垂直翻转
    rescale=1./255,           # 像素值缩放因子
    preprocessing_function=None,  # 自定义预处理函数
    validation_split=0.2      # 训练/验证集分割比例
)

常用方法

  1. flow_from_directory()
    从指定目录加载图像并生成批量数据:

    train_generator = datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary'  # 分类模式:'binary'、'categorical' 等
    )
    
  2. flow()
    从 Numpy 数组生成批量数据:

    datagen.flow(X_train, y_train, batch_size=32)
    
  3. fit()
    计算数据的统计信息(如均值、标准差),用于归一化:

    datagen.fit(X_train)
    
  4. save_to_dir()

    将增强后的图像保存到指定目录(调试用):

    train_generator = datagen.flow_from_directory(
        'data/train',
        save_to_dir='augmented_images'
    )
    

  5. 示例

    训练模型时使用 ImageDataGenerator

    from keras.src.legacy.preprocessing.image import ImageDataGenerator
    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
    
    # 创建数据生成器
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True
    )
    
    test_datagen = ImageDataGenerator(rescale=1./255)  # 验证集不增强
    
    # 加载数据
    train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary'
    )
    
    validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary'
    )
    
    # 构建模型
    model = Sequential([
        Conv2D(32, (3,3), activation='relu', input_shape=(150, 150, 3)),
        MaxPooling2D((2,2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    
    # 编译和训练
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.fit(
        train_generator,
        steps_per_epoch=train_generator.samples // 32,
        validation_data=validation_generator,
        validation_steps=validation_generator.samples // 32,
        epochs=10
    )
    

注意事项 

  1. 版本兼容性
    该类属于 Keras 旧版本(keras.src.legacy)。在 TensorFlow 2.0 及以后版本中,推荐使用 tf.keras.preprocessing.image.ImageDataGenerator 或更现代的 tf.data.Dataset API。

  2. 性能考虑
    实时数据增强会增加训练时间。对于大型数据集,建议使用 tf.data.Dataset 配合 GPU 加速。

  3. 自定义预处理
    通过 preprocessing_function 参数可传入自定义处理函数(如 OpenCV 操作)。

  4. 内存管理
    使用 flow_from_directory() 时,确保图像路径正确,且内存足够存储批量数据。

替代方案

  • tf.keras.preprocessing.image_dataset_from_directory
    更高效的数据集加载方式,返回 tf.data.Dataset 对象。
  • tf.keras.layers.RandomFliptf.keras.layers.RandomRotation
    层内数据增强(直接嵌入模型,仅在推理时应用)。

如果你使用的是 TensorFlow 2.x,建议优先选择现代 API,以获得更好的性能和兼容性。

      评论 1
      添加红包

      请填写红包祝福语或标题

      红包个数最小为10个

      红包金额最低5元

      当前余额3.43前往充值 >
      需支付:10.00
      成就一亿技术人!
      领取后你会自动成为博主和红包主的粉丝 规则
      hope_wisdom
      发出的红包
      实付
      使用余额支付
      点击重新获取
      扫码支付
      钱包余额 0

      抵扣说明:

      1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
      2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

      余额充值