和熊本熊一起学习图像增强方法

640?wx_fmt=jpeg

深度学习要取得较好的学习效果,通常对样本数量有一定的要求,在模型的研发过程中可以借助imagenet(具有1000多万张图片)等现成的大型数据集进行训练。但是在解决实际问题中,样本往往因为收集困难,缺乏历史数据等原因造成短缺,数量较少。

如何使用好手里有限的样本,进行充分利用,提升模型的泛化能力呢?除去模型及优化过程中的参数调节等原因,就样本本身,我们可以使用图像增强的方法。

一、什么是图像增强

简单的讲,图像增强就是利用已有图像通过一系列技术操作生成新样本的过程。常见的有翻转,旋转,平移,扭曲,或剪裁等。

使用Keras自带的图像增强模块—图像生成器(ImageDataGenerator)可帮我们快速实现上述过程。其有如下参数设置,其中前6个是对样本的预处理,若开启需对训练集和验证集同时使用,往后则为增强处理。

keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-6,
    rotation_range=0.,
    width_shift_range=0.,
    height_shift_range=0.,
    shear_range=0.,
    zoom_range=0.,
    channel_shift_range=0.,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=False,
    vertical_flip=False,
    rescale=None,
    preprocessing_function=None,
    data_format=K.image_data_format())

下面将具体详解各个参数的含义、对应代码样例,并由我们的模特-可爱的熊本(图片来自网络)示范使用后的效果。

640?wx_fmt=png

二、实现与效果
旋转

rotation_range:随机转动的最大角度

datagen = ImageDataGenerator(rotation_range=90)
640?wx_fmt=png
平移
width_shift_range:随机水平平移的最大幅度
(按原图比例)
height_shift_range: 上下平移
datagen = ImageDataGenerator(width_shift_range=0.3)
640?wx_fmt=png
datagen = ImageDataGenerator(height_shift_range=0.3)
640?wx_fmt=png
随机聚焦
zoom_range:随机缩放的最大幅度,
可输入单个浮点数或形如[lower,upper]的列表。
datagen = ImageDataGenerator(zoom_range=0.5)

640?wx_fmt=png

当为列表时,大于1表示缩小

datagen = ImageDataGenerator(zoom_range=[1.5,1.5])

640?wx_fmt=png

小于1表示放大

datagen = ImageDataGenerator(zoom_range=[0.5,0.5])

640?wx_fmt=png

介于1上下时,表示随机,且缩放比例不均衡

datagen = ImageDataGenerator(zoom_range=[0.5,1.5])

640?wx_fmt=png

切变拉伸

shear_range:随机切变拉伸的最大角度
datagen = ImageDataGenerator(shear_range = 50)

640?wx_fmt=png

曝光度

channel_shift_range:通道随机偏移的最大幅度

(效果为曝光度)

datagen = ImageDataGenerator(channel_shift_range=100)

640?wx_fmt=png

翻转

horizontal_flip:随机水平翻转
vertical_flip:随机竖直翻转
datagen = ImageDataGenerator(horizontal_flip=True)

640?wx_fmt=png

datagen = ImageDataGenerator(vertical_flip=True)

640?wx_fmt=png

缺失填充
fill_mode:当进行变换时超出边界点的填充方案。
可选‘constant’,‘nearest’,‘reflect’或‘wrap’,
默认为‘nearest’。
datagen = ImageDataGenerator(zoom_range=[3,3],fill_mode='nearest')

640?wx_fmt=png

自定义操作
preprocessing_function: 执行预定义的函数操作。
该函数将在图片缩放和数据提升之后运行。
此处为自行编写的随机剪裁操作。
datagen = ImageDataGenerator(preprocessing_function=random_crop_image)

640?wx_fmt=png

混合-完整代码

import numpy as np
from keras.preprocessing.image import ImageDataGenerator
import cv2
import matplotlib.pyplot as plt
#读取图片
from keras.preprocessing.image import array_to_img
im=cv2.imread('D:/img.jpg')
im= cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = np.expand_dims(im, 0)
#设置生成器
datagen = ImageDataGenerator(
    rotation_range=90,
    width_shift_range=0.2,
    height_shift_range=0.2,
    channel_shift_range=10,
    shear_range = 20,  
    preprocessing_function=random_crop_image,
    vertical_flip=True,
    zoom_range=0.2,
    horizontal_flip=True)
datagen.fit(im)
#生成并画图
times=9
i = 0
for batch in datagen.flow(im, batch_size=1, save_to_dir='D:/',save_prefix='img_new', save_format='jpeg'):
    print(i)
    ax = plt.subplot(3,3,i+1)
    plt.sca(ax)
    plt.axis('off')
    ax.set_title('sample_%s'%(i),fontsize=7)
    plt.imshow(array_to_img(np.squeeze(batch)))
    i += 1
    if i==times:
        plt.show()
        break

640?wx_fmt=png

三、训练中的使用

直接在文件夹中生成增强后的样本难免会占据大量硬盘空间,此时可以将图像生成器与模型的数据生成器绑定,仅在训练时完成图像的读取与增强。(使用时,需将图片按类别分别存放在不同的子文件夹下,文件夹可以用类别命名。另外增强方法越多,训练的速度越慢,请考虑算力酌情开启)

#设置图像生成器参数
train_datagen = ImageDataGenerator(
    featurewise_center=True,
    preprocessing_function=random_crop_image,
    rotation_range=90,
    width_shift_range=0.3,
    height_shift_range=0.3,
    channel_shift_range=10,
    shear_range = 30,  
    vertical_flip=True,
    zoom_range=0.2,
    fill_mode='constant',
    horizontal_flip=True)
val_datagen=ImageDataGenerator(featurewise_center=True)
#从文件夹下一次以batchsize大小读取图片
batch_size=32
train_generator = train_datagen.flow_from_directory(
        'E:/train',
        target_size=(224, 224),
        batch_size=batch_size,
        shuffle=True)
val_generator = val_datagen.flow_from_directory(
        'E:/val',
        target_size=(224, 224),
        batch_size=batch_size)
#计算所有训练集与验证集的样本大小
alltrainfile=[]
for i in range(9):
    namelist=os.listdir('E:/train/'+str(i+1)+'/')
    alltrainfile.extend(namelist)
train_sample_num=len(alltrainfile)
allvalfile=[]
for i in range(9):
    namelist=os.listdir('E:/val/'+str(i+1)+'/')
    allvalfile.extend(namelist)
val_sample_num=len(allvalfile)
#构建模型
model = Sequential()
model.add(Conv2D(filters=32,kernel_size=(3,3),input_shape=(224,224,3),activation='relu',padding='same'))
...
model.add(Dense(class_number, activation='softmax'))
model.compile(loss='categorical_crossentropy',
      optimizer='adam',
      metrics=['accuracy'])
#训练模型
train_history=model.fit_generator(train_generator,
      steps_per_epoch=train_sample_num//batch_size,
      epochs=100,validation_data=val_generator,
      validation_steps=val_sample_num//batch_size,
      verbose=1,callbacks=callback_list,shuffle=True)

除了keras图像生成器中的增强方法外,其他的增强方法还有添加高斯噪音,图像锐化等,更高级的还可以使用GAN(对抗式生成网络)生成以假乱真的样本。

Python中文社区作为一个去中心化的全球技术社区,以成为全球20万Python中文开发者的精神部落为愿景,目前覆盖各大主流媒体和协作平台,与阿里、腾讯、百度、微软、亚马逊、开源中国、CSDN等业界知名公司和技术社区建立了广泛的联系,拥有来自十多个国家和地区数万名登记会员,会员来自以工信部、清华大学、北京大学、北京邮电大学、中国人民银行、中科院、中金、华为、BAT、谷歌、微软等为代表的政府机关、科研单位、金融机构以及海内外知名公司,全平台近20万开发者关注。

640?wx_fmt=jpeg

▼ 点击成为社区注册会员      「在看」一下,一起PY

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值