语义分割数据增强

import os
import cv2
import random
import torch
import numpy as np
from PIL import Image
from cvtransforms import transforms
from cvtransforms import functional as tf

import datetime

'''
    语义分割数据增强时,需将图像和标签图同时操作,对于旋转,偏移等操作,会引入黑边(均为0),
    将引入的黑边 视为1类,标签值默认为0,真实标签从1开始。
    图像采用BILINEAR,标签图采用NEAREST
    目前采用 torchvision.transforms.functional 的API,此api与PIL的数据增强操作是一致的,只要转成PIL,均采用uint8
    功能包括随机旋转,水平翻转,垂直翻转,随机裁剪,随机缩放,高斯噪声,椒盐噪声
'''

class Augmentations_PIL:
    """Paired image/label augmentations for semantic segmentation.

    Every public method takes ``(image, label)`` and returns ``(image, label)``.
    Geometric operations sample their random parameters once and apply them to
    both inputs so the pair stays aligned; photometric operations (colour
    jitter, noise) change only the image and pass the label through.

    Each method reads its parameters from ``augment_mode[<method name>]``.
    Pixels exposed by a warp are filled with 0 in both image and label, so
    label value 0 doubles as the "border" class (real classes start at 1,
    according to the original author's notes).

    Inputs are expected to be PIL images (RGB uint8 image, uint8 label) per the
    original notes.  NOTE: every public, non-underscore method name on this
    class is treated as an augmentation name by ``Transforms_PIL``; keep
    helpers private.
    """

    def __init__(self, augment_mode):
        """Store the configuration dict keyed by augmentation (method) name."""
        self.augment_dict = augment_mode
        self.image_fill = 0  # fill value for image pixels exposed by a warp (black)
        self.label_fill = 0  # fill value for label pixels exposed by a warp
        self.input_hw = [512, 512]  # not read by any method in this class

    def rotation(self, image, label):
        """Rotate image and label by the same randomly sampled angle.

        Reads ``augment_dict["rotation"]["degrees"]`` and hands it to
        ``RandomRotation.get_params`` to draw the angle.

        :param image: PIL RGB uint8
        :param label: PIL uint8
        :return: (image, label) rotated, with exposed corners filled by the
                 configured fill values
        """
        degrees = self.augment_dict["rotation"]["degrees"]
        angle = transforms.RandomRotation.get_params(degrees)
        image = tf.rotate(image, angle, fill=self.image_fill)
        label = tf.rotate(label, angle, fill=self.label_fill)
        return image, label

    def flipH(self, image, label):
        """Flip image and label horizontally."""
        return tf.hflip(image), tf.hflip(label)

    def flipV(self, image, label):
        """Flip image and label vertically."""
        return tf.vflip(image), tf.vflip(label)

    def resizecrop(self, image, label):
        """Random resized crop ("zoom in") applied identically to image and label.

        Reads ``size`` and ``scale`` from ``augment_dict["resizecrop"]``.
        ``scale`` may be a single lower bound (the range becomes
        ``(scale, 1.0)``) or an explicit ``(min, max)`` pair.

        The crop window is sampled once and reused for both inputs.  Building
        two separate ``RandomResizedCrop`` transforms, as this method used to
        do, draws two different windows and misaligns image and label.  The
        label is resampled with NEAREST so no blended class ids are created.
        """
        cfg = self.augment_dict["resizecrop"]
        size, scale = cfg["size"], cfg["scale"]
        if isinstance(scale, (list, tuple)):
            scale_range = tuple(scale)
        else:
            scale_range = (scale, 1.0)

        # The transform instance is used only to normalise `size` and to supply
        # the default aspect-ratio range.
        # NOTE(review): assumes cvtransforms mirrors torchvision here
        # (`.size`/`.scale`/`.ratio` attributes, `get_params`, and
        # `functional.resized_crop(img, top, left, height, width, size,
        # interpolation)`) - confirm against the installed cvtransforms.
        cropper = transforms.RandomResizedCrop(size, scale_range)
        top, left, height, width = cropper.get_params(image, cropper.scale, cropper.ratio)

        image = tf.resized_crop(image, top, left, height, width, cropper.size, Image.BILINEAR)
        label = tf.resized_crop(label, top, left, height, width, cropper.size, Image.NEAREST)
        return image, label

    def perspective(self, image, label):
        """Apply the same random perspective warp ("zoom out") to image and label.

        Reads ``augment_dict["perspective"]["distortion_scale"]``.  The image
        is resampled with BICUBIC and the label with NEAREST; exposed areas
        are filled with the configured fill values.  Per the original notes
        the output keeps the input size.
        """
        distortion_scale = self.augment_dict["perspective"]["distortion_scale"]

        width, height = image.size
        startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, distortion_scale)
        image = tf.perspective(image, startpoints, endpoints, interpolation=Image.BICUBIC, fill=self.image_fill)
        label = tf.perspective(label, startpoints, endpoints, interpolation=Image.NEAREST, fill=self.label_fill)
        return image, label

    def affine(self, image, label):
        """Apply the same random affine transform to image and label.

        Reads ``degrees``, ``translate``, ``scale`` and ``shear`` from
        ``augment_dict["affine"]`` and samples one parameter set with
        ``RandomAffine.get_params``.  Both inputs are resampled with
        ``resample=0`` (PIL NEAREST) and exposed areas are filled with the
        configured fill values.
        """
        cfg = self.augment_dict["affine"]
        degrees, translate, scale_ranges = cfg["degrees"], cfg["translate"], cfg["scale"]
        shears = cfg["shear"]
        # params unpack as (angle, translations, scale, shear)
        params = transforms.RandomAffine.get_params(degrees, translate, scale_ranges, shears, img_size=image.size)
        image = tf.affine(image, *params, resample=0, fillcolor=self.image_fill)
        label = tf.affine(label, *params, resample=0, fillcolor=self.label_fill)
        return image, label

    def colorjitter(self, image, label):
        """Randomly jitter brightness/contrast/saturation/hue of the image only.

        Reads the four strengths from ``augment_dict["colorjitter"]``.  Per the
        original notes the randomness is in the sampled values, not in whether
        the jitter is applied.  The label is returned unchanged.
        """
        cfg = self.augment_dict["colorjitter"]
        jitter = transforms.ColorJitter(brightness=cfg["brightness"],
                                        contrast=cfg["contrast"],
                                        saturation=cfg["saturation"],
                                        hue=cfg["hue"])
        return jitter(image), label

    def gaussiannoise(self, image, label):
        """Add Gaussian noise to the image only; the label is returned unchanged.

        Reads ``mean`` and ``std`` from ``augment_dict["gaussiannoise"]``.
        """
        cfg = self.augment_dict["gaussiannoise"]
        # presumably 0.9 is the probability of applying the noise - verify
        # against cvtransforms.RandomGaussianNoise
        noise = transforms.RandomGaussianNoise(0.9, cfg["mean"], cfg["std"])
        return noise(image), label

    def spnoise(self, image, label):
        """Add salt-and-pepper noise to the image only; the label is returned unchanged.

        Reads ``prob`` from ``augment_dict["spnoise"]``.
        """
        prob = self.augment_dict["spnoise"]["prob"]
        # presumably 0.9 is the probability of applying the noise - verify
        # against cvtransforms.RandomSPNoise
        noise = transforms.RandomSPNoise(0.9, prob)
        return noise(image), label

    def self_images(self, image, label):
        """Identity augmentation: return the pair unchanged."""
        return image, label


class Transforms_PIL(object):
    """Apply one configured augmentation to an image/label pair per call.

    ``augment_list`` is a dict keyed by augmentation name; the same dict is
    handed to ``Augmentations_PIL`` as its per-method parameter source.
    """

    def __init__(self, augment_list):
        self.aug_pil = Augmentations_PIL(augment_list)
        self.augment_ways = augment_list
        # Names of the public methods of the augmenter (instance attributes and
        # underscore-prefixed names excluded): the valid augmentation names.
        self.aug_funcs = [a for a in self.aug_pil.__dir__() if not a.startswith('_') and a not in self.aug_pil.__dict__]

    def __call__(self, image, label):
        '''
        Pick one configured augmentation and apply it.

        Only keys of the configuration that name an augmentation method are
        candidates; other keys (general settings stored in the same dict) are
        ignored instead of being looked up as methods, which used to raise
        AttributeError whenever such a key was drawn.  With exactly one
        candidate it is used directly (no random draw); with several, one is
        drawn uniformly via ``np.random.choice``.

        :param image:  PIL RGB uint8
        :param label:  PIL, uint8
        :return:  (image, label) on success; -1 (after printing a warning)
                  when no usable augmentation is configured
        '''
        candidates = [name for name in self.augment_ways if str(name) in self.aug_funcs]
        if not candidates:
            print("增强无效,请检查!")
            return -1

        if len(candidates) == 1:
            name = candidates[0]
        else:
            name = np.random.choice(candidates)
        image, label = getattr(self.aug_pil, str(name))(image, label)
        return image, label

class ToTensor(object):
    """Convert an (image, label) pair into tensors.

    The image goes through ``tf.to_tensor`` (per the original notes: HWC to
    CHW and division by 255).  The label is converted PIL -> ndarray -> tensor
    and returned as a long (int64) tensor.
    """

    def __call__(self, image, label):
        image_tensor = tf.to_tensor(image)
        label_tensor = torch.from_numpy(np.array(label))
        if isinstance(label_tensor, torch.LongTensor):
            # already int64: hand it back untouched
            return image_tensor, label_tensor
        return image_tensor, label_tensor.long()

def get_current_time():
    """Return the current local time as ``YYYYMMDDHHMMSS_ffffff``.

    Every field is zero-padded to a fixed width.  Concatenating the raw
    integer fields, as this function used to do, is ambiguous (2021-1-12 and
    2021-11-2 both start with ``2021112``), so two different instants could
    map to the same string.  The result is used to build output file names,
    so it has to be unique per instant and is more convenient when it sorts
    chronologically.
    """
    return datetime.datetime.now().strftime("%Y%m%d%H%M%S_%f")

def test_01():
    """Build the sample augmentation configuration used by the demo script.

    Keys are augmentation names; values are that augmentation's parameter
    dict, or ``True`` for the parameterless flips.  The resizecrop,
    perspective and affine augmentations are not enabled in this sample.
    """
    return {
        "flipH": True,
        "flipV": True,
        "spnoise": {"prob": 0.05},
        "gaussiannoise": {"mean": 0.12, "std": 0.1},
        "rotation": {"degrees": (-60, 80)},
        "colorjitter": {
            "brightness": 0.56,
            "contrast": 0.55,
            "saturation": 0.12,
            "hue": 0.1,
        },
    }
if __name__ == '__main__':
    # Demo: run `num` passes over an image folder, apply one random augmentation
    # per image/label pair and save the results under timestamped file names.
    a = test_01()
    trans = Transforms_PIL(a)

    num = 4
    path = r"E:\\AI_datas\\2021_12_20_sengment_data2\\train2\\image"
    path1 = r"E:\\AI_datas\\2021_12_20_sengment_data2\\train2\\label"
    out_image_dir = "E:\\AI_datas\\2021_12_20_sengment_data2\\train2\\train\\image\\"
    out_label_dir = "E:\\AI_datas\\2021_12_20_sengment_data2\\train2\\train\\label\\"

    # The outputs go to different folders, so both listings are loop-invariant:
    # read them once instead of once per image per pass.
    image_names = [n for n in os.listdir(path)
                   if n.split('.')[-1] in ['jpg', 'png', 'bmp', 'PNG']]
    label_ext = None
    if image_names:
        label_names = os.listdir(path1)
        if not label_names:
            raise FileNotFoundError("no label files found in " + path1)
        # All labels are assumed to share the extension of the first label file.
        label_ext = label_names[0].split('.')[-1]

    for _ in range(num):
        for name in image_names:
            # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), whereas
            # split('.')[0] would truncate it to "a" and miss the label file.
            stem = os.path.splitext(name)[0]
            image_path = os.path.join(path, name)
            label_path = os.path.join(path1, stem + '.' + label_ext)
            # Context managers close the source files once the pair is saved.
            with Image.open(image_path) as image, Image.open(label_path) as label:
                result = trans(image, label)
                if not isinstance(result, tuple):
                    # Transforms_PIL returns -1 when no augmentation is usable.
                    continue
                image1, label1 = result
                current_time = get_current_time()
                image1.save(out_image_dir + str(current_time) + ".bmp")
                label_out = out_label_dir + str(current_time) + ".png"
                label1.save(label_out)
            print(label_out)


  • 4
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

沐雨清风 จุ๊บ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值