为了尽量利用我们有限的训练数据, 我们将通过一系列随机变换对数据进行增强(augmentation), 这样我们的模型将看不到任何两张完全相同的图片, 这有利于我们抑制过拟合, 使得模型的泛化能力更好。
在Keras中, 这个步骤可以通过keras.preprocessing.image.ImageDataGenerator来实现。
ImageDataGenerator class
keras.preprocessing.image.ImageDataGenerator(
featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-06,
rotation_range=0,
width_shift_range=0.0,
height_shift_range=0.0,
brightness_range=None,
shear_range=0.0,
zoom_range=0.0,
channel_shift_range=0.0,
fill_mode='nearest',
cval=0.0,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=None,
validation_split=0.0, dtype=None)
具体的代码如下:
class myAugmentation(object):
    """
    Image-augmentation helper for paired train/label images.

    Pipeline:
        1. Read the training images and label images, then merge each
           train/label pair into a single image (the label's channel 0 is
           copied into channel 2 of the train image) so that both receive
           exactly the same random transform in the next stage.
        2. Use Keras' ImageDataGenerator to augment the merged images.
        3. Split every augmented merged image back into a training image
           and a label image.
    """

    def __init__(self, train_path="../deform/train", label_path="../deform/label",
                 merge_path="../DataGen/merge", aug_merge_path="../DataGen/aug_merge",
                 aug_train_path="../DataGen/aug_train", aug_label_path="../DataGen/aug_label"):
        """
        Collect every file from ``train_path`` and ``label_path`` with glob,
        build the shared ImageDataGenerator, and create the output folders.

        Parameters mirror the directory layout:
            train_path / label_path -- original train / label data sets
            merge_path              -- where merged (pre-augmentation) images go
            aug_merge_path          -- where augmented merged images go
            aug_train_path / aug_label_path -- where the split results go
        """
        self.train_imgs = glob.glob(train_path + "/*")
        self.label_imgs = glob.glob(label_path + "/*")
        self.train_path = train_path
        self.label_path = label_path
        self.merge_path = merge_path
        self.aug_merge_path = aug_merge_path
        self.aug_train_path = aug_train_path
        self.aug_label_path = aug_label_path
        # One merged image per train/label pair.
        self.slices = len(self.train_imgs)
        self.datagen = ImageDataGenerator(
            rotation_range=180,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.05,
            zoom_range=0.1,
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='nearest')
        # makedirs(..., exist_ok=True) also creates missing parent folders;
        # the original os.mkdir raised FileNotFoundError in that case.
        for path in (self.merge_path, self.aug_merge_path,
                     self.aug_label_path, self.aug_train_path):
            os.makedirs(path, exist_ok=True)

    def Augmentation(self):
        """
        Merge every train/label pair into one image, save it under
        ``merge_path``, then augment the merged image into a per-pair
        subdirectory of ``aug_merge_path``.

        Returns 0 (after printing a message) when the train and label
        lists do not match or are empty; otherwise returns None.
        """
        trains = self.train_imgs
        labels = self.label_imgs
        path_merge = self.merge_path
        path_aug_merge = self.aug_merge_path
        # BUG FIX: the original tested `len(trains) == 0` twice; the second
        # check was clearly intended to guard against an empty label list.
        if len(trains) != len(labels) or len(trains) == 0 or len(labels) == 0:
            print("trains can't match labels")
            return 0
        for i in range(len(trains)):
            img_t = load_img(trains[i])
            img_l = load_img(labels[i])
            x_t = img_to_array(img_t)
            x_l = img_to_array(img_l)
            # Stash the label (channel 0) inside channel 2 of the train
            # image so one array carries both through augmentation.
            x_t[:, :, 2] = x_l[:, :, 0]
            img_tmp = array_to_img(x_t)
            img_tmp.save(path_merge + "/" + str(i) + ".tif")
            img = x_t.reshape((1,) + x_t.shape)  # add batch dimension for flow()
            savedir = path_aug_merge + "/" + str(i)
            if not os.path.lexists(savedir):
                os.mkdir(savedir)
            self.doAugmentate(img, savedir, str(i))

    def doAugmentate(self, img, save_to_dir, save_prefix, batch_size=1,
                     save_format='tif', imgnum=10):
        """
        Augment one batched image ``imgnum`` times.

        ``ImageDataGenerator.flow`` loops forever, so we count iterations
        and break once ``imgnum`` augmented copies have been written to
        ``save_to_dir``.
        """
        count = 0
        for _ in self.datagen.flow(img,
                                   batch_size=batch_size,
                                   save_to_dir=save_to_dir,
                                   save_prefix=save_prefix,
                                   save_format=save_format):
            count += 1
            if count >= imgnum:
                break

    def splitMerge(self):
        """
        Split every augmented merged image back apart.

        cv2 reads images as BGR, so channel 2 of the loaded array is the
        original red channel (the train image) and channel 0 is the
        original blue channel (where the label was stashed).
        """
        path_merge = self.aug_merge_path
        path_train = self.aug_train_path
        path_label = self.aug_label_path
        for i in range(self.slices):
            path = path_merge + "/" + str(i)
            train_imgs = glob.glob(path + "/*.tif")
            for savedir in (path_train + "/" + str(i),
                            path_label + "/" + str(i)):
                if not os.path.lexists(savedir):
                    os.mkdir(savedir)
            for imgname in train_imgs:
                # BUG FIX: the original used imgname.rindex("\\"), which only
                # works with Windows path separators; basename is portable.
                midname = os.path.basename(imgname)
                img = cv2.imread(imgname)
                img_train = img[:, :, 2]  # cv2 reads BGR: ch 2 = train image
                img_label = img[:, :, 0]  # ch 0 = label image
                cv2.imwrite(path_train + "/" + str(i) + "/" + midname, img_train)
                cv2.imwrite(path_label + "/" + str(i) + "/" + midname, img_label)
tips:要告知程序:
- train_path:原始的训练数据集
- label_path:原始的标签数据集
- merge_path:原始数据集融合后的存放地址
- aug_merge_path:融合数据增强后的存放地址
- aug_train_path:增强后的训练数据集
- aug_label_path:增强后的标签数据集
使用:
if __name__ == "__main__":
    # Run the full pipeline: merge each train/label pair, augment the
    # merged images, then split the augmented results back apart.
    augmenter = myAugmentation()
    augmenter.Augmentation()
    augmenter.splitMerge()