数据增强(Data Augmentation)
(本人水平有限,请广大读者批评指正!!!!)
深度学习通常需要大量的数据作为支撑,看到那些公开的数据集,少的也有几十万张,但是在现实中,我们能拥有的数据集网络没有那么到。但是数据量少,往往会造成过拟合等问题,因此需要一些“奇巧淫技”来增强数据,正好本人在看斯坦福的CS231N课程中的这方面介绍,因此做个总结。
结合课程和网上查看的资料,将Data Augmentation总结如下:
1、水平/竖直翻转。
2、随机crop。
3、颜色改变。
4、仿射/旋转变换
5、随机改变大小
6、加噪声
7、·······
下面对上述方法中部分进行具体介绍:
1、Keras
Keras是以tensorflow或theano作为后端的一个极易上手的框架,本人比较懒,所以研究生阶段用的最多的也就是Keras。在Keras中专门有一个图像数据增加的工具ImageDataGenerator。它能满足数据增强的大部分需求。
直接上代码:
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
Datagen = ImageDataGenerator(rotation_range=40,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
vertical_flip = True
fill_mode='nearest')
#还有其他一些参数,具体请看:https://keras.io/preprocessing/image/ ,如去均值,标准化,ZCA白化,旋转,
#偏移,翻转,缩放等
img = load_img('../data/hand1.jpg')#获取一个PIL图像
x_img = img_to_array(img)
x_img = x_img.reshape((1,)+ x_img.shape)
i = 0
for img_batch in Datagen.flow(x_img,
batch_size=1,
save_to_dir='../data/pre_Data/'
save_prefix='hand',
save_format='jpeg'):
i +=1
if i > 20:
break
2、caffe
现在在用caffe,想研究一下caffe中的数据增强,只发现mirror、scale、crop三种,查看了一些资料,需要自己添加一些数据增强的代码,因此最近一直在研究caffe源码(一个c++很烂的人研究源码,也是蛋疼),。
后续补充。
3、提供一个链接:https://github.com/aleju/imgaug
看上去效果很好:
4、PCA Jittering的实现:
PCA Jittering最早是由Alex在他2012年赢得的ImageNet竞赛的那篇NIPS中提出的,首先按照RGB三个颜色通道计算均值和标准差,对网络的输入数据进行规范化,随后我们在整个训练集上计算了协方差矩阵,进行特征分解,得到特征向量和特征值,用来做PCA Jittering。
本文根据:https://www.zhihu.com/question/35339639中提供的PCA Jittering的代码做了下实验。代码如下
# -*- coding: utf-8 -*-
"""
Created on Wed May 10 10:00:53 2017
@author: xx
"""
import numpy as np
import os
from PIL import Image, ImageOps
import argparse
import random
from scipy import misc
def PCA_Jittering(path):
img_list = os.listdir(path)
img_num = len(img_list)
for i in range(img_num):
img_path = os.path.join(path, img_list[i])
img = Image.open(img_path)
img = np.asanyarray(img, dtype = 'float32')
img = img / 255.0
img_size = img.size / 3
img1 = img.reshape(img_size, 3)
img1 = np.transpose(img1)
img_cov = np.cov([img1[0], img1[1], img1[2]])
lamda, p = np.linalg.eig(img_cov)
p = np.transpose(p)
alpha1 = random.normalvariate(0,3)
alpha2 = random.normalvariate(0,3)
alpha3 = random.normalvariate(0,3)
v = np.transpose((alpha1*lamda[0], alpha2*lamda[1], alpha3*lamda[2]))
add_num = np.dot(p,v)
img2 = np.array([img[:,:,0]+add_num[0], img[:,:,1]+add_num[1], img[:,:,2]+add_num[2]])
img2 = np.swapaxes(img2,0,2)
img2 = np.swapaxes(img2,0,1)
save_name = 'pre'+str(i)+'.png'
save_path = os.path.join(path, save_name)
misc.imsave(save_path,img2)
#plt.imshow(img2)
#plt.show()
效果如下: