keras.src.legacy.preprocessing.image.ImageDataGenerator
是 Keras 中用于图像数据增强和批量生成训练数据的工具类,属于 Keras 旧版本的图像处理模块。它支持实时数据增强(如旋转、翻转、缩放等),适用于图像分类、目标检测等计算机视觉任务。
主要功能
- 数据增强:通过随机变换生成多样化的训练样本,提高模型的泛化能力。
- 批量生成:从目录或数组中批量加载图像数据,减少内存占用。
- 归一化:对图像像素值进行标准化处理(如缩放至 [0,1] 或归一化到均值为 0)。
核心参数
datagen = ImageDataGenerator(
rotation_range=20, # 随机旋转角度范围
width_shift_range=0.2, # 水平平移范围
height_shift_range=0.2, # 垂直平移范围
shear_range=0.2, # 剪切强度
zoom_range=0.2, # 缩放范围
horizontal_flip=True, # 随机水平翻转
vertical_flip=False, # 随机垂直翻转
rescale=1./255, # 像素值缩放因子
preprocessing_function=None, # 自定义预处理函数
validation_split=0.2 # 训练/验证集分割比例
)
常用方法
-
flow_from_directory()
从指定目录加载图像并生成批量数据:train_generator = datagen.flow_from_directory( 'data/train', target_size=(150, 150), batch_size=32, class_mode='binary' # 分类模式:'binary'、'categorical' 等 )
-
flow()
从 Numpy 数组生成批量数据:datagen.flow(X_train, y_train, batch_size=32)
-
fit()
计算数据的统计信息(如均值、标准差),用于归一化:datagen.fit(X_train)
-
save_to_dir()
将增强后的图像保存到指定目录(调试用):
train_generator = datagen.flow_from_directory( 'data/train', save_to_dir='augmented_images' )
-
示例
训练模型时使用 ImageDataGenerator
from keras.src.legacy.preprocessing.image import ImageDataGenerator from keras.models import Sequential from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense # 创建数据生成器 train_datagen = ImageDataGenerator( rescale=1./255, rotation_range=20, width_shift_range=0.2, height_shift_range=0.2, horizontal_flip=True ) test_datagen = ImageDataGenerator(rescale=1./255) # 验证集不增强 # 加载数据 train_generator = train_datagen.flow_from_directory( 'data/train', target_size=(150, 150), batch_size=32, class_mode='binary' ) validation_generator = test_datagen.flow_from_directory( 'data/validation', target_size=(150, 150), batch_size=32, class_mode='binary' ) # 构建模型 model = Sequential([ Conv2D(32, (3,3), activation='relu', input_shape=(150, 150, 3)), MaxPooling2D((2,2)), Flatten(), Dense(128, activation='relu'), Dense(1, activation='sigmoid') ]) # 编译和训练 model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) model.fit( train_generator, steps_per_epoch=train_generator.samples // 32, validation_data=validation_generator, validation_steps=validation_generator.samples // 32, epochs=10 )
注意事项
-
版本兼容性:
该类属于 Keras 旧版本(keras.src.legacy
)。在 TensorFlow 2.0 及以后版本中,推荐使用tf.keras.preprocessing.image.ImageDataGenerator
或更现代的tf.data.Dataset
API。 -
性能考虑:
实时数据增强会增加训练时间。对于大型数据集,建议使用tf.data.Dataset
配合 GPU 加速。 -
自定义预处理:
通过preprocessing_function
参数可传入自定义处理函数(如 OpenCV 操作)。 -
内存管理:
使用flow_from_directory()
时,确保图像路径正确,且内存足够存储批量数据。
替代方案
tf.keras.preprocessing.image_dataset_from_directory
更高效的数据集加载方式,返回tf.data.Dataset
对象。tf.keras.layers.RandomFlip
,tf.keras.layers.RandomRotation
层内数据增强(直接嵌入模型,仅在推理时应用)。
如果你使用的是 TensorFlow 2.x,建议优先选择现代 API,以获得更好的性能和兼容性。