import os
import cv2
import random
import torch
import numpy as np
from PIL import Image
from cvtransforms import transforms
from cvtransforms import functional as tf
import datetime
'''
语义分割数据增强时,需将图像和标签图同时操作,对于旋转,偏移等操作,会引入黑边(均为0值),
将引入的黑边 视为1类,标签值默认为0,真实标签从1开始。
图像采用BILINEAR,标签图采用NEAREST
目前采用 cvtransforms 中的 transforms / functional API(其接口仿照 torchvision.transforms 与 torchvision.transforms.functional),此API与PIL的数据增强操作一致;输入需先转成PIL图像,均采用uint8。
功能包括随机旋转,水平翻转,垂直翻转,随机裁剪缩放(resizecrop),透视变换,随机仿射,颜色抖动,高斯噪声,椒盐噪声。
'''
class Augmentations_PIL:
    """Paired augmentations for semantic segmentation (image + label map).

    Each public method takes ``(image, label)`` -- a PIL RGB image and its PIL
    label map -- applies one augmentation and returns the new pair.

    Geometric operations draw their random parameters once and apply them to
    both inputs, so the pair stays pixel-aligned. Photometric operations
    (color jitter, noise) modify only the image. Regions exposed by rotation,
    perspective or affine warps are filled with ``image_fill`` /
    ``label_fill`` (both 0), i.e. the introduced "black border" is label
    class 0.

    Where an interpolation mode is passed explicitly, images use BILINEAR
    (BICUBIC in ``perspective``) and labels use NEAREST, so that no
    intermediate class ids are created. ``rotation`` relies on
    ``tf.rotate``'s default mode.

    Parameters for each method are read from ``augment_mode[<method name>]``.
    """

    def __init__(self, augment_mode):
        """
        :param augment_mode: dict mapping augmentation method name -> its
            parameter dict (e.g. ``{"rotation": {"degrees": (-10, 10)}}``);
            parameterless ops such as ``flipH`` may map to a flag.
        """
        self.augment_dict = augment_mode
        self.image_fill = 0  # fill value for image pixels exposed by warps (black)
        self.label_fill = 0  # fill value for label pixels exposed by warps (class 0)
        self.input_hw = [512, 512]  # target [H, W]; NOTE(review): not read by this class

    # Train stage. Every method below is a single, standalone operation; they
    # are not meant to be composed. Inputs are RGB 3-channel images ([H, W, C])
    # and 1-channel label maps ([H, W]). Per the original design notes, every
    # output still has to be resized to ``input_hw``. Nothing in this class
    # does that, so it is presumably the caller's job.

    @staticmethod
    def _symmetric_range(value):
        """Expand a single number ``d`` to ``(-d, d)``; return anything else unchanged.

        This mirrors how torchvision's RandomRotation/RandomAffine
        constructors treat a scalar ``degrees``/``shear``. Their
        ``get_params`` helpers only accept sequences. ``None`` and sequences
        pass through untouched.

        :raises ValueError: for a negative single number.
        """
        if isinstance(value, (int, float)):
            if value < 0:
                raise ValueError("a single-number range must be non-negative, got %r" % (value,))
            return (-float(value), float(value))
        return value

    @staticmethod
    def _scale_range(scale):
        """Normalize the ``resizecrop`` scale setting to a ``(low, high)`` tuple.

        A single number ``s`` means ``(s, 1.0)``, which is the original
        meaning of the setting. A 2-element sequence is taken as the range
        itself.
        """
        if isinstance(scale, (int, float)):
            return (float(scale), 1.0)
        low, high = scale
        return (float(low), float(high))

    def rotation(self, image, label):
        """Rotate image and label by the same random angle.

        Config: ``augment_mode["rotation"]["degrees"]`` -- a ``(min, max)``
        range in degrees, or a single number ``d`` meaning ``(-d, d)``.

        :param image: PIL RGB uint8
        :param label: PIL uint8
        :return: (image, label), both PIL
        """
        degrees = self._symmetric_range(self.augment_dict["rotation"]["degrees"])
        angle = transforms.RandomRotation.get_params(degrees)
        image = tf.rotate(image, angle, fill=self.image_fill)
        label = tf.rotate(label, angle, fill=self.label_fill)
        return image, label

    def flipH(self, image, label):
        """Flip image and label horizontally (left <-> right)."""
        image = tf.hflip(image)
        label = tf.hflip(label)
        return image, label

    def flipV(self, image, label):
        """Flip image and label vertically (top <-> bottom)."""
        image = tf.vflip(image)
        label = tf.vflip(label)
        return image, label

    def resizecrop(self, image, label):
        """Zoom in: crop the SAME random region from image and label and resize it to ``size``.

        Config: ``augment_mode["resizecrop"]``
          - ``size``: output size passed to ``tf.resized_crop``
          - ``scale``: crop-area fraction, a number ``s`` meaning ``(s, 1.0)``
            or a ``(low, high)`` pair
          - ``ratio`` (optional): aspect-ratio range, default ``(3/4, 4/3)``
            (torchvision's RandomResizedCrop default)
        """
        params = self.augment_dict["resizecrop"]
        size = params["size"]
        scale = self._scale_range(params["scale"])
        ratio = params.get("ratio", (3.0 / 4.0, 4.0 / 3.0))
        # Draw the crop box once so image and label stay aligned.
        top, left, height, width = transforms.RandomResizedCrop.get_params(image, scale, ratio)
        image = tf.resized_crop(image, top, left, height, width, size, interpolation=Image.BILINEAR)
        # NEAREST keeps label values as existing class ids.
        label = tf.resized_crop(label, top, left, height, width, size, interpolation=Image.NEAREST)
        return image, label

    def perspective(self, image, label):
        """Zoom out: apply the same random perspective warp to image and label.

        Config: ``augment_mode["perspective"]["distortion_scale"]``.
        Exposed areas are filled with 0 and the output keeps the input size.
        """
        distortion_scale = self.augment_dict["perspective"]["distortion_scale"]
        width, height = image.size
        startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, distortion_scale)
        image = tf.perspective(image, startpoints, endpoints, interpolation=Image.BICUBIC, fill=self.image_fill)
        label = tf.perspective(label, startpoints, endpoints, interpolation=Image.NEAREST, fill=self.label_fill)
        return image, label

    def affine(self, image, label):
        """Apply the same random affine transform (rotate/translate/scale/shear) to image and label.

        Config: ``augment_mode["affine"]`` with keys ``degrees`` (range or
        single number ``d`` -> ``(-d, d)``), ``translate``, ``scale`` and
        ``shear`` (range, single number ``s`` -> ``(-s, s)``, or None).
        The transform is centred on the image centre. Exposed areas are
        filled with 0 and the output keeps the input size.
        """
        cfg = self.augment_dict["affine"]
        degrees = self._symmetric_range(cfg["degrees"])
        shears = self._symmetric_range(cfg["shear"])
        # ret = (angle, translations, scale, shear)
        ret = transforms.RandomAffine.get_params(degrees, cfg["translate"], cfg["scale"], shears,
                                                 img_size=image.size)
        image = tf.affine(image, *ret, resample=Image.BILINEAR, fillcolor=self.image_fill)
        label = tf.affine(label, *ret, resample=Image.NEAREST, fillcolor=self.label_fill)
        return image, label

    def colorjitter(self, image, label):
        """Randomly jitter brightness/contrast/saturation/hue of the image; the label is untouched.

        The randomness is in the jitter amounts, not in whether it happens
        (that would be ``transforms.RandomApply``).
        Config: ``augment_mode["colorjitter"]`` with keys ``brightness``,
        ``contrast``, ``saturation``, ``hue``.
        """
        cfg = self.augment_dict["colorjitter"]
        transforms_func = transforms.ColorJitter(brightness=cfg["brightness"],
                                                 contrast=cfg["contrast"],
                                                 saturation=cfg["saturation"],
                                                 hue=cfg["hue"])
        image = transforms_func(image)
        return image, label

    def gaussiannoise(self, image, label):
        """Add Gaussian noise to the image; the label is untouched.

        Config: ``augment_mode["gaussiannoise"]`` with keys ``mean`` and ``std``.
        NOTE(review): 0.9 is passed as the first argument, presumably the
        probability of applying the noise -- confirm in cvtransforms.
        """
        cfg = self.augment_dict["gaussiannoise"]
        transforms_func = transforms.RandomGaussianNoise(0.9, cfg["mean"], cfg["std"])
        image = transforms_func(image)
        return image, label

    def spnoise(self, image, label):
        """Add salt-and-pepper noise to the image; the label is untouched.

        Config: ``augment_mode["spnoise"]["prob"]``.
        NOTE(review): 0.9 is passed as the first argument, presumably the
        probability of applying the noise -- confirm in cvtransforms.
        """
        prob = self.augment_dict["spnoise"]["prob"]
        transforms_func = transforms.RandomSPNoise(0.9, prob)
        image = transforms_func(image)
        return image, label

    def self_images(self, image, label):
        """Identity 'augmentation': return the pair unchanged."""
        return image, label
class Transforms_PIL(object):
    """Apply ONE randomly chosen augmentation from a config dict to an (image, label) pair."""

    def __init__(self, augment_list):
        """
        :param augment_list: dict mapping ``Augmentations_PIL`` method names
            to their parameters, or to a flag such as ``True``. Despite the
            name it is a dict, not a list.
        """
        self.aug_pil = Augmentations_PIL(augment_list)
        self.augment_ways = augment_list
        # Public methods of aug_pil: __dir__() lists attributes AND methods,
        # while __dict__ holds only instance attributes, so the difference is
        # the methods.
        self.aug_funcs = [a for a in self.aug_pil.__dir__()
                          if not a.startswith('_') and a not in self.aug_pil.__dict__]

    @staticmethod
    def _enabled_names(augment_ways, available):
        """Return config keys that name an available augmentation and are not disabled.

        A value of ``False`` or ``None`` disables an entry. Keys that are not
        augmentation names (e.g. ``"Augment_ratio"``) are ignored instead of
        crashing ``getattr``. Config order is preserved.
        """
        return [name for name, value in augment_ways.items()
                if name in available and value is not False and value is not None]

    def __call__(self, image, label):
        """Augment ``(image, label)`` with one enabled augmentation picked uniformly at random.

        :param image: PIL RGB uint8
        :param label: PIL uint8
        :return: (image, label), both PIL
        :raises ValueError: if the config enables no known augmentation.
        """
        # Re-read the config on every call: the dict is shared with the
        # caller and may be edited after construction.
        names = self._enabled_names(self.augment_ways, self.aug_funcs)
        if not names:
            raise ValueError("增强无效,请检查! No enabled augmentation in %r" % (list(self.augment_ways),))
        # Skip the RNG draw when there is no choice to make. np.random.choice
        # returns numpy.str_, hence str().
        name = names[0] if len(names) == 1 else str(np.random.choice(names))
        return getattr(self.aug_pil, name)(image, label)
class ToTensor(object):
    """Turn a PIL (image, label) pair into tensors.

    The image goes through ``tf.to_tensor``, which per the original note
    transposes HWC -> CHW and divides by 255. The label is converted without
    any scaling and returned as an int64 (``long``) tensor.
    """

    def __call__(self, image, label):
        image_tensor = tf.to_tensor(image)
        # PIL -> ndarray -> tensor; values stay the raw class ids.
        label_tensor = torch.from_numpy(np.array(label))
        if isinstance(label_tensor, torch.LongTensor):
            return image_tensor, label_tensor
        return image_tensor, label_tensor.long()
def get_current_time(now=None):
    """Return a timestamp string used to name output files.

    Format: ``YYYYMMDDhhmmss_ffffff`` with every field zero-padded. Padding
    keeps the string unambiguous: unpadded, 2021-01-11 and 2021-11-01 would
    both start with "2021111". It also makes the strings sort
    chronologically.

    :param now: optional ``datetime.datetime`` to format; defaults to
        ``datetime.datetime.now()`` (local, naive time).
    :return: str
    """
    if now is None:
        now = datetime.datetime.now()
    return now.strftime("%Y%m%d%H%M%S_%f")
def test_01():
    """Build the sample augmentation config used by the demo script.

    Each key names an ``Augmentations_PIL`` method. The value is either a
    flag or the dict of parameters that method reads. A fresh dict is built
    on every call.
    """
    # Other supported entries, left out of this sample (example values only):
    #   "resizecrop":  {"size": 448, "scale": 0.8}
    #   "perspective": {"distortion_scale": 0.5}
    #   "affine":      {"degrees": (-10, 10), "translate": (0.05, 0.05),
    #                   "scale": (0.95, 1.02), "shear": (-10, 10)}
    colorjitter_params = {
        "brightness": 0.56,
        "contrast": 0.55,
        "saturation": 0.12,
        "hue": 0.1,
    }
    return {
        "flipH": True,
        "flipV": True,
        "spnoise": {"prob": 0.05},
        "gaussiannoise": {"mean": 0.12, "std": 0.1},
        "rotation": {"degrees": (-60, 80)},
        "colorjitter": colorjitter_params,
    }
if __name__ == '__main__':
    # Demo: write `num` augmented copies of every image/label pair in the
    # training folders. Each pair is saved under one shared timestamp name:
    # image as .bmp, label as .png.
    augment_mode = test_01()
    trans = Transforms_PIL(augment_mode)
    num = 4
    path = r"E:\AI_datas\2021_12_20_sengment_data2\train2\image"
    path1 = r"E:\AI_datas\2021_12_20_sengment_data2\train2\label"
    out_image_dir = r"E:\AI_datas\2021_12_20_sengment_data2\train2\train\image"
    out_label_dir = r"E:\AI_datas\2021_12_20_sengment_data2\train2\train\label"
    os.makedirs(out_image_dir, exist_ok=True)
    os.makedirs(out_label_dir, exist_ok=True)

    # All labels are assumed to share the extension of the first file in
    # path1. This is the same assumption as before, but it is now computed
    # once instead of re-listing the directory for every image.
    label_ext = os.path.splitext(os.listdir(path1)[0])[1]
    image_exts = {'.jpg', '.png', '.bmp'}
    image_names = [n for n in os.listdir(path) if os.path.splitext(n)[1].lower() in image_exts]

    for _ in range(num):
        for name in image_names:
            # splitext keeps stems that themselves contain dots intact.
            stem = os.path.splitext(name)[0]
            image_path = os.path.join(path, name)
            label_path = os.path.join(path1, stem + label_ext)
            # Save inside the `with`: some augmentations (self_images,
            # colorjitter for the label) return the opened image object
            # itself, and the files are closed on exit.
            with Image.open(image_path) as image, Image.open(label_path) as label:
                image1, label1 = trans(image, label)
                current_time = get_current_time()
                label_out = os.path.join(out_label_dir, current_time + ".png")
                image1.save(os.path.join(out_image_dir, current_time + ".bmp"))
                label1.save(label_out)
            print(label_out)
# 语义分割数据增强
# 最新推荐文章于 2023-08-16 18:36:40 发布