深度学习要取得较好的学习效果,通常对样本数量有一定的要求,在模型的研发过程中可以借助imagenet(具有1000多万张图片)等现成的大型数据集进行训练。但是在解决实际问题中,样本往往因为收集困难,缺乏历史数据等原因造成短缺,数量较少。
如何使用好手里有限的样本,进行充分利用,提升模型的泛化能力呢?除去模型及优化过程中的参数调节等原因,就样本本身,我们可以使用图像增强的方法。
一、什么是图像增强
简单的讲,图像增强就是利用已有图像通过一系列技术操作生成新样本的过程。常见的有翻转,旋转,平移,扭曲,或剪裁等。
使用Keras自带的图像增强模块—图像生成器(ImageDataGenerator)可帮我们快速实现上述过程。其有如下参数设置,其中前6个是对样本的预处理,若开启需对训练集和验证集同时使用,往后则为增强处理。
keras.preprocessing.image.ImageDataGenerator(
featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-6,
rotation_range=0.,
width_shift_range=0.,
height_shift_range=0.,
shear_range=0.,
zoom_range=0.,
channel_shift_range=0.,
fill_mode='nearest',
cval=0.,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=K.image_data_format())
rotation_range:随机转动的最大角度
datagen = ImageDataGenerator(rotation_range=90)

datagen = ImageDataGenerator(width_shift_range=0.3)

datagen = ImageDataGenerator(height_shift_range=0.3)

datagen = ImageDataGenerator(zoom_range=0.5)
当为列表时,大于1表示缩小
datagen = ImageDataGenerator(zoom_range=[1.5,1.5])
小于1表示放大
datagen = ImageDataGenerator(zoom_range=[0.5,0.5])
介于1上下时,表示随机,且缩放比例不均衡
datagen = ImageDataGenerator(zoom_range=[0.5,1.5])
切变拉伸
datagen = ImageDataGenerator(shear_range = 50)
channel_shift_range:通道随机偏移的最大幅度
(效果为曝光度)
datagen = ImageDataGenerator(channel_shift_range=100)
翻转
datagen = ImageDataGenerator(horizontal_flip=True)
datagen = ImageDataGenerator(vertical_flip=True)
datagen = ImageDataGenerator(zoom_range=[3,3],fill_mode='nearest')
datagen = ImageDataGenerator(preprocessing_function=random_crop_image)
混合-完整代码
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
import cv2
import matplotlib.pyplot as plt
#读取图片
from keras.preprocessing.image import array_to_img
im=cv2.imread('D:/img.jpg')
im= cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = np.expand_dims(im, 0)
#设置生成器
datagen = ImageDataGenerator(
rotation_range=90,
width_shift_range=0.2,
height_shift_range=0.2,
channel_shift_range=10,
shear_range = 20,
preprocessing_function=random_crop_image,
vertical_flip=True,
zoom_range=0.2,
horizontal_flip=True)
datagen.fit(im)
#生成并画图
times=9
i = 0
for batch in datagen.flow(im, batch_size=1, save_to_dir='D:/',save_prefix='img_new', save_format='jpeg'):
print(i)
ax = plt.subplot(3,3,i+1)
plt.sca(ax)
plt.axis('off')
ax.set_title('sample_%s'%(i),fontsize=7)
plt.imshow(array_to_img(np.squeeze(batch)))
i += 1
if i==times:
plt.show()
break
三、训练中的使用
直接在文件夹中生成增强后的样本难免会占据大量硬盘空间,此时可以将图像生成器与模型的数据生成器绑定,仅在训练时完成图像的读取与增强。(使用时,需将图片按类别分别存放在不同的子文件夹下,文件夹可以用类别命名。另外增强方法越多,训练的速度越慢,请考虑算力酌情开启)
#设置图像生成器参数
train_datagen = ImageDataGenerator(
featurewise_center=True,
preprocessing_function=random_crop_image,
rotation_range=90,
width_shift_range=0.3,
height_shift_range=0.3,
channel_shift_range=10,
shear_range = 30,
vertical_flip=True,
zoom_range=0.2,
fill_mode='constant',
horizontal_flip=True)
val_datagen=ImageDataGenerator(featurewise_center=True)
#从文件夹下一次以batchsize大小读取图片
batch_size=32
train_generator = train_datagen.flow_from_directory(
'E:/train',
target_size=(224, 224),
batch_size=batch_size,
shuffle=True)
val_generator = val_datagen.flow_from_directory(
'E:/val',
target_size=(224, 224),
batch_size=batch_size)
#计算所有训练集与验证集的样本大小
alltrainfile=[]
for i in range(9):
namelist=os.listdir('E:/train/'+str(i+1)+'/')
alltrainfile.extend(namelist)
train_sample_num=len(alltrainfile)
allvalfile=[]
for i in range(9):
namelist=os.listdir('E:/val/'+str(i+1)+'/')
allvalfile.extend(namelist)
val_sample_num=len(allvalfile)
#构建模型
model = Sequential()
model.add(Conv2D(filters=32,kernel_size=(3,3),input_shape=(224,224,3),activation='relu',padding='same'))
...
model.add(Dense(class_number, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
#训练模型
train_history=model.fit_generator(train_generator,
steps_per_epoch=train_sample_num//batch_size,
epochs=100,validation_data=val_generator,
validation_steps=val_sample_num//batch_size,
verbose=1,callbacks=callback_list,shuffle=True)
除了keras图像生成器中的增强方法外,其他的增强方法还有添加高斯噪音,图像锐化等,更高级的还可以使用GAN(对抗式生成网络)生成以假乱真的样本。
Python中文社区作为一个去中心化的全球技术社区,以成为全球20万Python中文开发者的精神部落为愿景,目前覆盖各大主流媒体和协作平台,与阿里、腾讯、百度、微软、亚马逊、开源中国、CSDN等业界知名公司和技术社区建立了广泛的联系,拥有来自十多个国家和地区数万名登记会员,会员来自以工信部、清华大学、北京大学、北京邮电大学、中国人民银行、中科院、中金、华为、BAT、谷歌、微软等为代表的政府机关、科研单位、金融机构以及海内外知名公司,全平台近20万开发者关注。
▼ 点击成为社区注册会员 「在看」一下,一起PY