Tensorflow2.0中数据增强及训练过程中数据增强

数据增强:它是正则化的一种形式,使我们的网络可以更好地将其推广到我们的测试/验证集。

ImageDataGenerator工作原理:

ImageDataGenerator接受原始数据,对其进行随机转换,并仅返回转换后的新数据。接受一批用于训练的图像;
进行此批处理并对批处理中的每个图像应用一系列随机变换(包括随机旋转,调整大小,剪切等);
用新的,随机转换的批次替换原始批次;
在此随机转换的批次上训练CNN(即原始数据本身不用于训练)。

1)对图片进行数据增强,并将其结果保存到文件夹中:

""
一: 定义ImageDataGenerator 图片生成器
""
from tensorflow.keras.preprocessing.image import ImageDataGenerator


""
二: 封装flow_from_directory()
其中:
path:文件读入的路径,必须是子文件夹的上一级(这里是个坑,不过试一哈就懂了)
target_size:图片resize成的尺寸,不设置会默认设置为(256.256)
batch_size:每次输入的图片的数量,例如batch_size=32,一次进行增强的数量为32,
个人经验:batch_size的大小最好是应该和文件的数量是可以整除的关系
save_to_dir:增强后图片的保存位置
save_prefix:文件名加前缀,方便查看
save_format:保存图片的数据格式
产生的图片总数:batch_size*6(即range中的数字)
""
gen = datagen.flow_from_directory(
                           path,
                           target_size=(224, 224),
                           batch_size=15,
                           save_to_dir=dst_path,
                           save_prefix='xx',
                           save_format='jpg')

""
三: 调用gen.next()执行增强过程
""
for i in range(6):
    gen.next()

2)训练过程中数据增强

from tensorflow.keras.preprocessing.image import ImageDataGenerator

""
一: 定义ImageDataGenerator
""
datagen = ImageDataGenerator(
        # 布尔值,使输入数据集去中心化(均值为0), 按feature执行
        featurewise_center=False,

        # 布尔值,使输入数据的每个样本均值为0
        samplewise_center=False,

        # 布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行
        featurewise_std_normalization=False,

        # 布尔值,将输入的每个样本除以其自身的标准差
        samplewise_std_normalization=False,

        # 布尔值,对输入数据施加ZCA白化
        zca_whitening=False,

        # ZCA使用的eposilon,默认1e-6
        zca_epsilon=1e-06,

        # 整数,数据提升时图片随机转动的角度 (deg 0 to 180)
        rotation_range=0,

        # 浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
        width_shift_range=0.1,

        # 浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度
        height_shift_range=0.1,

        # 浮点数,剪切强度(逆时针方向的剪切变换角度)
        shear_range=0.,

        # 浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] 
        #  = [1 - zoom_range, 1+zoom_range]
        zoom_range=0.,

        # 浮点数,随机通道偏移的幅度
        channel_shift_range=0.,

        # ‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数 
        # 给定的方法进行处理
        fill_mode='nearest',

        # 浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值
        cval=0.,

        # 布尔值,是否进行随机水平翻转
        horizontal_flip=True,

        # 布尔值,是否进行随机竖直翻转
        vertical_flip=False,

        # 重放缩因子,默认为None. 如果为None或0则不进行放缩,否则会将该数值乘到数据上(在应用其 
        # 他变换之前)
        rescale=None,

        # 将被应用于每个输入的函数。该函数将在图片缩放和数据提升之后运行。该函数接受一个参数, 
        # 为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array
        preprocessing_function=None,

        # 字符串,“channel_first”或“channel_last”之一,代表图像的通道维的位置。
        data_format=None,

        # 验证集切分比重 (strictly between 0 and 1)
        validation_split=0.0)

""
二:fit中调用
1):fit ---->fit_generator
2):传入数据集变为datagen.flow(x_train, y_train, batch_size=batch_size)
""
model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                    steps_per_epoch=len(x_train) // batch_size,
                    epochs=epochs, verbose=1, workers=4,
                    callbacks=callbacks,
                    use_multiprocessing=False)

fit 中的 verbose:

verbose:日志显示
verbose = 0 为不在标准输出流输出日志信息
verbose = 1 为输出进度条记录
verbose = 2 为每个epoch输出一行记录
注意: 默认为 1

  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值