Image Segmentation (2): Data Augmentation for Small Datasets

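To make the most of our limited training data, we augment it with a series of random transformations, so that the model never sees exactly the same image twice. This helps suppress overfitting and gives the model better generalization.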
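For segmentation there is one extra constraint: every random transform must be applied identically to an image and its label, otherwise the pixels no longer line up. A minimal NumPy sketch of that idea, restricted to flips and 90-degree rotations (the function name `random_flip_rot` is hypothetical, not part of Keras):

```python
import numpy as np


def random_flip_rot(image, mask, rng):
    """Apply the same random flip / 90-degree rotation to an image and its mask."""
    # Draw the random parameters once, then reuse them for both arrays
    k = int(rng.integers(0, 4))        # number of 90-degree rotations
    flip_h = bool(rng.integers(0, 2))  # flip left-right?
    flip_v = bool(rng.integers(0, 2))  # flip up-down?

    def transform(a):
        a = np.rot90(a, k, axes=(0, 1))
        if flip_h:
            a = a[:, ::-1]
        if flip_v:
            a = a[::-1, :]
        return a

    return transform(image), transform(mask)


rng = np.random.default_rng(0)
image = np.arange(16, dtype=np.float32).reshape(4, 4)
mask = (image > 7).astype(np.uint8)
for _ in range(3):
    aug_img, aug_mask = random_flip_rot(image, mask, rng)
    # The mask still marks exactly the pixels whose value is > 7
    assert np.array_equal(aug_mask, (aug_img > 7).astype(np.uint8))
```

Drawing the parameters once and closing over them in `transform` is what keeps image and mask in sync; the Keras class below offers a much richer set of transforms.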

In Keras, this step can be implemented with keras.preprocessing.image.ImageDataGenerator.
ImageDataGenerator class

keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False, 
    samplewise_center=False, 
    featurewise_std_normalization=False, 
    samplewise_std_normalization=False, 
    zca_whitening=False, 
    zca_epsilon=1e-06, 
    rotation_range=0, 
    width_shift_range=0.0, 
    height_shift_range=0.0, 
    brightness_range=None, 
    shear_range=0.0, 
    zoom_range=0.0, 
    channel_shift_range=0.0, 
    fill_mode='nearest', 
    cval=0.0, 
    horizontal_flip=False, 
    vertical_flip=False, 
    rescale=None, 
    preprocessing_function=None, 
    data_format=None, 
    validation_split=0.0,
    dtype=None)
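As a rough illustration of two of these parameters: a float `width_shift_range` below 1 shifts the image horizontally by a random fraction of its width drawn from [-width_shift_range, width_shift_range], and `fill_mode='nearest'` fills the uncovered border by repeating the nearest edge pixel (aaaa|abcd|dddd). The NumPy sketch below mimics that with whole-pixel shifts only; Keras itself performs the shift through an interpolated affine transform, so this is not its implementation, and `shift_width` / `random_width_shift` are hypothetical names.

```python
import numpy as np


def shift_width(img, dx):
    """Shift an (H, W[, C]) array right by dx pixels (left if dx < 0),
    filling the exposed columns by repeating the nearest edge column."""
    if dx == 0:
        return img.copy()
    w = img.shape[1]
    pad = abs(dx)
    pad_width = [(0, 0)] * img.ndim
    pad_width[1] = (pad, pad)
    padded = np.pad(img, pad_width, mode="edge")  # edge replication = 'nearest'
    # Output column c takes original column c - dx
    start = pad - dx
    return padded[:, start:start + w]


def random_width_shift(img, shift_range, rng):
    """Rough analogue of width_shift_range=shift_range (a fraction < 1)."""
    w = img.shape[1]
    dx = int(round(rng.uniform(-shift_range, shift_range) * w))
    return shift_width(img, dx)


row = np.array([[0, 1, 2, 3]])
print(shift_width(row, 1).tolist())   # [[0, 0, 1, 2]]
print(shift_width(row, -1).tolist())  # [[1, 2, 3, 3]]
```

With `width_shift_range=0.2` on a 512-pixel-wide image, the shift would fall anywhere between roughly -102 and +102 pixels.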

The full code is as follows:

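import glob
import os

import cv2
from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img


class myAugmentation(object):
    """
    A class for image augmentation:
    First: read the training images and the labels, and merge each image with its label for the next stage
    Then: augment the merged images with Keras preprocessing
    Finally: split the augmented images back into training images and training labels
    """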

    def __init__(self, train_path="../deform/train", label_path="../deform/label", merge_path="../DataGen/merge", aug_merge_path="../DataGen/aug_merge",
                 aug_train_path="../DataGen/aug_train", aug_label_path="../DataGen/aug_label"):

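        """
        Initialize the class: use glob to collect every file under train_path and label_path
        """
        # sorted() keeps trains[i] and labels[i] aligned; glob's order is arbitrary
        self.train_imgs = sorted(glob.glob(train_path + "/*"))
        self.label_imgs = sorted(glob.glob(label_path + "/*"))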
        self.train_path = train_path
        self.label_path = label_path
        self.merge_path = merge_path
        self.aug_merge_path = aug_merge_path
        self.aug_train_path = aug_train_path
        self.aug_label_path = aug_label_path
        self.slices = len(self.train_imgs)
        self.datagen = ImageDataGenerator(
            rotation_range=180,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.05,
            zoom_range=0.1,
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='nearest')
        # os.makedirs also creates missing parent directories (e.g. ../DataGen)
        for path in (self.merge_path, self.aug_merge_path, self.aug_label_path, self.aug_train_path):
            os.makedirs(path, exist_ok=True)

    def Augmentation(self):

        """
        Start augmentation.....
        """
        trains = self.train_imgs
        labels = self.label_imgs
        path_merge = self.merge_path
        path_aug_merge = self.aug_merge_path
        if len(trains) != len(labels) or len(trains) == 0 or len(labels) == 0:
            print("trains can't match labels")
            return 0
        for i in range(len(trains)):
            img_t = load_img(trains[i])
            img_l = load_img(labels[i])
            x_t = img_to_array(img_t)
            x_l = img_to_array(img_l)

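            # Store the label in channel 2 (blue, RGB order) so one transform moves image and label together
            x_t[:, :, 2] = x_l[:, :, 0]
            img_tmp = array_to_img(x_t)
            img_tmp.save(path_merge + "/" + str(i) + ".tif")
            img = x_t
            # flow() expects a batch: add a leading axis -> (1, H, W, C)
            img = img.reshape((1,) + img.shape)
            savedir = path_aug_merge + "/" + str(i)
            os.makedirs(savedir, exist_ok=True)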
            self.doAugmentate(img, savedir, str(i))

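    def doAugmentate(self, img, save_to_dir, save_prefix, batch_size=1, save_format='tif', imgnum=10):
        """
        Augment one image: save imgnum randomly transformed copies to save_to_dir
        """
        datagen = self.datagen
        # datagen.flow() yields batches indefinitely, so stop after imgnum batches
        i = 0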
        for batch in datagen.flow(img,
                                  batch_size=batch_size,
                                  save_to_dir=save_to_dir,
                                  save_prefix=save_prefix,
                                  save_format=save_format):
            i += 1
            if i >= imgnum:
                break

    def splitMerge(self):
        """
        Split each augmented merged image back into a training image and a label
        """
        path_merge = self.aug_merge_path
        path_train = self.aug_train_path
        path_label = self.aug_label_path

        for i in range(self.slices):
            path = path_merge + "/" + str(i)
            train_imgs = glob.glob(path + "/*.tif")
            os.makedirs(path_train + "/" + str(i), exist_ok=True)
            os.makedirs(path_label + "/" + str(i), exist_ok=True)
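            for imgname in train_imgs:
                # File name without its directory (portable across operating systems)
                midname = os.path.basename(imgname)
                img = cv2.imread(imgname)
                # cv2.imread returns BGR: channel 2 is the original red (image), channel 0 the original blue (label)
                img_train = img[:, :, 2]
                img_label = img[:, :, 0]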
                cv2.imwrite(path_train + "/" + str(i) + "/" + midname, img_train)
                cv2.imwrite(path_label + "/" + str(i) + "/" + midname, img_label)
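Why the merge trick works: packing the label into one colour channel of the training image means each merged picture passes through `datagen.flow` once, so every random rotation, shift, or flip lands identically on image and label. It implicitly assumes grayscale training images, since `splitMerge` keeps only the red channel as the image. The channel indices differ between the two methods because PIL saves in RGB order while `cv2.imread` returns BGR. A NumPy sketch of the round trip (the helper names are hypothetical, and the reversed array stands in for the save/reload step):

```python
import numpy as np


def merge_label_into_rgb(img_rgb, label_rgb):
    """Mirror of Augmentation(): overwrite channel 2 (blue, in RGB order)
    of the training image with channel 0 of the label image."""
    merged = img_rgb.copy()
    merged[:, :, 2] = label_rgb[:, :, 0]
    return merged


def split_bgr(merged_bgr):
    """Mirror of splitMerge(): in BGR order the label (former blue) is
    index 0 and the image (former red) is index 2."""
    return merged_bgr[:, :, 2], merged_bgr[:, :, 0]


# Grayscale inputs loaded as RGB have three identical channels
rng = np.random.default_rng(1)
gray = rng.integers(0, 256, size=(4, 4)).astype(np.uint8)
mask = np.where(gray > 127, 255, 0).astype(np.uint8)
img_rgb = np.stack([gray] * 3, axis=-1)
label_rgb = np.stack([mask] * 3, axis=-1)

merged_rgb = merge_label_into_rgb(img_rgb, label_rgb)
merged_bgr = merged_rgb[:, :, ::-1]  # channel order as cv2 would read it back
train, label = split_bgr(merged_bgr)
assert np.array_equal(train, gray) and np.array_equal(label, mask)
```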

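Tips: the program needs to be told the following paths:

  1. train_path: the original training images
  2. label_path: the original label images
  3. merge_path: where the merged original image/label pairs are stored
  4. aug_merge_path: where the augmented merged images are stored
  5. aug_train_path: the augmented training images
  6. aug_label_path: the augmented labels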
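Because `Augmentation()` pairs `trains[i]` with `labels[i]` by position, the two input folders must line up file for file. The sketch below, assuming each label has the same file name as its training image (an assumption about your dataset), pairs files by name and fails loudly on a mismatch; it also creates the output folders together with any missing parents. Both helpers, `paired_paths` and `ensure_dirs`, are hypothetical.

```python
import glob
import os


def paired_paths(train_path, label_path):
    """Hypothetical helper: return (train, label) file pairs matched by file
    name, assuming labels share their training image's name."""
    trains = {os.path.basename(p): p for p in glob.glob(os.path.join(train_path, "*"))}
    labels = {os.path.basename(p): p for p in glob.glob(os.path.join(label_path, "*"))}
    unmatched = sorted(set(trains) ^ set(labels))  # names present on one side only
    if unmatched:
        raise ValueError("unmatched files: " + ", ".join(unmatched))
    return [(trains[name], labels[name]) for name in sorted(trains)]


def ensure_dirs(*paths):
    """Create each directory, including missing parents such as ../DataGen."""
    for p in paths:
        os.makedirs(p, exist_ok=True)
```

Usage would look like `ensure_dirs("../DataGen/merge", "../DataGen/aug_merge")` followed by `pairs = paired_paths("../deform/train", "../deform/label")`. Note that names sort as strings, so "10.tif" comes before "2.tif"; that is harmless here because pairing is by name, not by numeric order.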

Usage:

if __name__ == "__main__":
    aug = myAugmentation()
    aug.Augmentation()
    aug.splitMerge()
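One thing worth checking in the output: rotation, shear and zoom resample the merged image with interpolation, so pixels along label edges in aug_label can end up with intermediate grey values instead of pure 0/255. If your labels are binary masks, a threshold after `splitMerge()` restores them. A sketch, assuming 0/255 masks and a threshold of 127 (both assumptions; `binarize_mask` is a hypothetical name):

```python
import numpy as np


def binarize_mask(label, threshold=127):
    """Map an interpolated uint8 label back to {0, 255}.
    The 0/255 encoding and the threshold are assumptions about the dataset."""
    return np.where(label > threshold, 255, 0).astype(np.uint8)


# An edge smeared by interpolation
edge = np.array([[0, 64, 191, 255]], dtype=np.uint8)
print(binarize_mask(edge).tolist())  # [[0, 0, 255, 255]]
```

Applied to the saved files, this would mean reading each label with `cv2.imread(path, cv2.IMREAD_GRAYSCALE)`, passing it through `binarize_mask`, and writing it back with `cv2.imwrite`.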
Author: 书生伯言

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值