Keras数据增强ImageDataGenerator

在实际训练自己的网络模型时,基本都会遇到数据不够的难题.

Keras.preprocessing.imgae.ImageDataGenerator图片生成器,可以批量生成数据,防止模型过拟合并提高泛化能力.

使用方法如下:

#coding:utf-8
from keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
 
#定义图片生成器
data_gen = ImageDataGenerator(rotation_range=40,
                              width_shift_range=0.2,
                              height_shift_range=0.2,
                              horizontal_flip=True,
                              vertical_flip=True,
                              fill_mode='nearest',
                              data_format='channels_last')
 
img=load_img('/home/zyx/Dataset/flower_photos/daisy/144603918_b9de002f60_m.jpg')
x = img_to_array(img,data_format="channels_last")   #图片转化成array类型,因flow()接收numpy数组为参数
x=x.reshape((1,) + x.shape)     #要求为4维
 
#使用for循环迭代,生成图片
i = 0
for batch in data_gen.flow(x,batch_size=1,
                           save_to_dir='/home/zyx/Dataset/flower_photos/dataGen',
                           save_prefix='flower',
                           save_format='jpeg'):
    print batch.shape
    i += 1
    if i>20:
        break
        
#使用next()迭代,生成图片
next(data_gen.flow(x,batch_size=2,
                   save_to_dir='/home/zyx/Dataset/flower_photos/dataGen',
                   save_prefix='next_gen',
                   save_format='jpeg'))
data_gen.flow()返回是数据类型是NumpyArrayIterator类型,需要使用for或next,才可得到其中数据.

我们已经知道,可以直接作用于for循环的数据类型有以下几种:

一类是集合数据类型,如list、tuple、dict、set、str等;

一类是generator,包括生成器和带yield的generator function。

而生成器不但可以作用于for循环,还可以被next()函数不断调用并返回下一个值,直到最后抛出StopIteration错误表示无法继续返回下一个值了。

可以被next()函数调用并不断返回下一个值的对象称为迭代器:Iterator。

可以使用isinstance()判断一个对象是否是Iterator对象:

但是,这样生成的数据,相似度很高,在原始样本少的情况下,模型依旧很容易过拟合.
 

ImageDataGeneratorKeras中一个非常方便的图像数据生成器,主要用于数据增强data augmentation)和实时数据扩充(real-time data augmentation)。它可以自动将一批原始图像转换为训练所需的随机数据,比如随机旋转、缩放、翻转等操作,从而扩大训练数据集,提高模型的泛化能力。 ImageDataGenerator类可以通过定义不同的参数来实现各种图像增强的操作,例如: - rotation_range:旋转角度范围; - width_shift_range、height_shift_range:图像水平、垂直方向的平移范围; - shear_range:剪切强度; - zoom_range:随机缩放范围; - horizontal_flip、vertical_flip:是否随机水平、垂直翻转图像; - fill_mode:填充模式。 使用ImageDataGenerator类时,需要先通过fit()方法计算出数据集的统计信息,然后可以通过flow()方法进行数据生成。例如: ```python from keras.preprocessing.image import ImageDataGenerator datagen = ImageDataGenerator( rotation_range=20, width_shift_range=0.2, height_shift_range=0.2, shear_range=0.2, zoom_range=0.2, horizontal_flip=True) datagen.fit(x_train) model.fit_generator(datagen.flow(x_train, y_train, batch_size=32), steps_per_epoch=len(x_train)/32, epochs=50) ``` 这段代码中,我们先定义了一个ImageDataGenerator对象,然后定义了各种图像增强的参数。接着,我们通过fit()方法计算出数据集的统计信息,最后通过flow()方法生成扩充后的数据集,用于训练模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值