# Data augmentation (数据增强).
import random
import numpy as np
import cv2
from PIL import Image
import os
# --- Augmentation configuration; augmentation() looks these up on every call ---
# Resize stage: the longer image side becomes base_size (a falsy value disables
# the stage); when `scale` is true it is instead drawn from [0.5, 2.0] * base_size.
base_size, scale = 1024, True
# Rotation stage: random integer angle in [-10, 10] degrees about the centre.
rotate = True
# Crop stage: pad bottom/right up to crop_size, then take a random
# crop_size x crop_size window (a falsy value disables the stage).
crop_size = 1200
# Random horizontal flip (image and label together); Gaussian blur (image only).
flip, blur = True, False
def _save_debug(pp_path, image, filename):
    """Write ``image`` to ``pp_path/filename`` for visual inspection; no-op when ``pp_path`` is None."""
    if pp_path is not None:
        # NOTE(review): Image.fromarray assumes a PIL-compatible dtype (typically uint8) — confirm for callers.
        Image.fromarray(image).save(os.path.join(pp_path, filename))


def _resize_long_side(image, label, long_side):
    """Resize image and label so the longer side equals ``long_side``, preserving aspect ratio."""
    h, w = image.shape[:2]
    if h > w:
        new_h, new_w = long_side, int(1.0 * long_side * w / h + 0.5)
    else:
        new_h, new_w = int(1.0 * long_side * h / w + 0.5), long_side
    # cv2.resize takes (width, height). The label uses nearest-neighbour so
    # class ids are never interpolated into values that do not exist.
    image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    label = cv2.resize(label, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    return image, label


def _rotate(image, label, angle, label_fill):
    """Rotate image and label by ``angle`` degrees about the image centre, keeping the canvas size."""
    h, w = image.shape[:2]
    rot_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
    # Content rotated out of frame is dropped; uncovered areas are filled with
    # 0 in the image and `label_fill` in the label.
    image = cv2.warpAffine(image, rot_matrix, (w, h), flags=cv2.INTER_LINEAR)
    label = cv2.warpAffine(
        label, rot_matrix, (w, h), flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT, borderValue=label_fill,
    )
    return image, label


def _pad_and_random_crop(image, label, size, label_fill):
    """Pad bottom/right up to ``size`` if needed, then take the same random ``size`` x ``size`` window from both."""
    h, w = image.shape[:2]
    pad_h = max(size - h, 0)
    pad_w = max(size - w, 0)
    if pad_h > 0 or pad_w > 0:
        # Padding only at the bottom/right keeps existing pixel coordinates unchanged.
        pad_kwargs = {
            "top": 0,
            "bottom": pad_h,
            "left": 0,
            "right": pad_w,
            "borderType": cv2.BORDER_CONSTANT,
        }
        image = cv2.copyMakeBorder(image, value=0, **pad_kwargs)
        label = cv2.copyMakeBorder(label, value=label_fill, **pad_kwargs)
        h, w = image.shape[:2]
    # randint is inclusive, so h - size (>= 0 after padding) is a valid offset.
    start_h = random.randint(0, h - size)
    start_w = random.randint(0, w - size)
    end_h = start_h + size
    end_w = start_w + size
    return image[start_h:end_h, start_w:end_w], label[start_h:end_h, start_w:end_w]


def augmentation(pp_path, image, label, *, label_fill=0):
    """Randomly augment an image together with its label map.

    Stages run in this order. Each is enabled by a module-level setting that is
    read at call time:

    * resize (``base_size``, ``scale``): the longer side becomes ``base_size``,
      or a random value in [0.5, 2.0] * ``base_size`` when ``scale`` is true;
    * rotate (``rotate``): a random integer angle in [-10, 10] degrees;
    * crop (``crop_size``): pad bottom/right to ``crop_size``, then take a
      random ``crop_size`` x ``crop_size`` window;
    * flip (``flip``): a horizontal flip with probability 0.5;
    * blur (``blur``): a Gaussian blur with sigma in [0, 1), applied to the
      image only.

    All geometric stages are applied identically to image and label; the label
    always uses nearest-neighbour interpolation.

    Args:
        pp_path: Directory for the per-stage debug JPEGs (``rgb_resize.jpg``,
            ``rgb_rotate.jpg``, ``rgb_crop.jpg``, ``rgb_flip.jpg``), or None to
            skip writing them.
        image: Array of shape (H, W) or (H, W, C).
        label: Array whose first two dimensions match ``image``.
        label_fill: Label value used for areas exposed by rotation or padding.
            The default is 0; an "ignore" id (e.g. 255) may be preferable if 0
            is a real class — confirm against the training loss.

    Returns:
        Tuple ``(image, label)`` of augmented arrays.
    """
    if base_size:
        long_side = random.randint(int(base_size * 0.5), int(base_size * 2.0)) if scale else base_size
        image, label = _resize_long_side(image, label, long_side)
        _save_debug(pp_path, image, "rgb_resize.jpg")
    if rotate:
        angle = random.randint(-10, 10)
        image, label = _rotate(image, label, angle, label_fill)
        _save_debug(pp_path, image, "rgb_rotate.jpg")
    if crop_size:
        image, label = _pad_and_random_crop(image, label, crop_size, label_fill)
        _save_debug(pp_path, image, "rgb_crop.jpg")
    if flip:
        if random.random() > 0.5:
            # .copy() turns the flipped view into a contiguous array for cv2/PIL.
            image = np.fliplr(image).copy()
            label = np.fliplr(label).copy()
        _save_debug(pp_path, image, "rgb_flip.jpg")
    if blur:
        sigma = random.random()
        ksize = int(3.3 * sigma)
        if ksize % 2 == 0:
            ksize += 1  # GaussianBlur requires an odd kernel size
        image = cv2.GaussianBlur(
            image, (ksize, ksize), sigmaX=sigma, sigmaY=sigma,
            borderType=cv2.BORDER_REFLECT_101,
        )
    return image, label
if __name__ == '__main__':
    # Demo: augment one sample and write the per-stage debug JPEGs into its directory.
    import sys

    default_pp_path = "/data/realseeData/test/pano-surface/73de190196212982ee0ed55ef3686496/derived/1552296451"
    # The sample directory may be passed as the first CLI argument.
    pp_path = sys.argv[1] if len(sys.argv) > 1 else default_pp_path
    # The label PNG is named after the directory itself; normpath tolerates a trailing slash.
    sample_id = os.path.basename(os.path.normpath(pp_path))
    # Context managers close the file handles once the pixels are copied into arrays.
    # convert("RGB") guarantees a 3-channel image even for grayscale/CMYK JPEGs.
    with Image.open(os.path.join(pp_path, "rgb_image_align.jpg")) as rgb_file:
        rgb_image_array = np.array(rgb_file.convert("RGB"))
    with Image.open(os.path.join(pp_path, sample_id + ".png")) as label_file:
        label_image_array = np.array(label_file)
    image, label = augmentation(pp_path, rgb_image_array, label_image_array)