Python offline data augmentation (expanding a dataset)

The augmentation methods include (a minimal usage sketch follows this list):
    (1) Pixel-level augmentation
        1. Brightness adjustment
        2. Adding noise
    (2) Image-level augmentation
        3. Cropping (the bboxes must be adjusted)
        4. Translation (the bboxes must be adjusted)
        5. Mirroring (the bboxes must be adjusted)
        6. Rotation (the bboxes must be adjusted)
        7. Occlusion (random erasing)
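
For reference, here is a minimal usage sketch of the augmenter defined in the full listing below (the image path and box values are placeholders; boxes use the [x_min, y_min, x_max, y_max, class] layout the code expects):

    import cv2

    aug = DataAugmentForObjectDetection()          # class from the full listing below
    img = cv2.imread('sample.jpg')                 # (h, w, c) numpy array
    bboxes = [[48, 240, 195, 371, 0],              # [x_min, y_min, x_max, y_max, class]
              [8, 12, 352, 498, 1]]
    auged_img, auged_bboxes = aug.dataAugment(img, bboxes)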

The tricky part: brightness, noise, cropping, translation, mirroring, and occlusion were fairly easy to implement, but adjusting the labels after rotation caused trouble; the issue was eventually solved by revising the bbox-correction step, as shown below.

Original approach:

            X_MIN = min(X1, X2, X3, X4)
            X_MAX = max(X1, X2, X3, X4)
            Y_MIN = min(Y1, Y2, Y3, Y4)
            Y_MAX = max(Y1, Y2, Y3, Y4)

Improved approach:

            NEW_X1 = (X1 + X3) / 2
            NEW_X2 = (X2 + X3) / 2
            NEW_X3 = (X2 + X4) / 2
            NEW_X4 = (X1 + X4) / 2

            NEW_Y1 = (Y1 + Y3) / 2
            NEW_Y2 = (Y2 + Y3) / 2
            NEW_Y3 = (Y2 + Y4) / 2
            NEW_Y4 = (Y1 + Y4) / 2
            X_MIN = min(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
            X_MAX = max(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
            Y_MIN = min(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)
            Y_MAX = max(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)

Result: after rotation, the adjusted labels fit the detection targets well.
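
To see what the improved version does: each corner (x, y) of the original box is rotated about the image centre (a, b), and averaging adjacent rotated corners before taking min/max pulls the axis-aligned box back in, so it hugs the object instead of circumscribing the whole rotated rectangle. A small self-contained sketch of just this step (the centre, angle, and box values are illustrative only):

    import math

    a, b = 320.0, 240.0                                   # image centre (illustrative)
    angle = math.radians(30)                              # rotation angle (illustrative)
    x1, y1, x2, y2 = 200.0, 200.0, 300.0, 260.0           # original axis-aligned box
    corners = [(x1, y1), (x2, y2), (x1, y2), (x2, y1)]    # same corner order as the full code

    def rot(p):
        x, y = p
        return ((x - a) * math.cos(angle) - (y - b) * math.sin(angle) + a,
                (x - a) * math.sin(angle) + (y - b) * math.cos(angle) + b)

    (X1, Y1), (X2, Y2), (X3, Y3), (X4, Y4) = [rot(p) for p in corners]

    # The midpoint trick: average adjacent rotated corners before taking min/max.
    xs = [(X1 + X3) / 2, (X2 + X3) / 2, (X2 + X4) / 2, (X1 + X4) / 2]
    ys = [(Y1 + Y3) / 2, (Y2 + Y3) / 2, (Y2 + Y4) / 2, (Y1 + Y4) / 2]
    print(min(xs), min(ys), max(xs), max(ys))             # tightened [x_min, y_min, x_max, y_max]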

The full code follows:

# -*- coding=utf-8 -*-
# -----------------------------------------------------------------
# Description:
#     data augmentation for object detection
# Augmentation methods:
#     (1) Pixel-level augmentation
#     1. Brightness adjustment
#     2. Adding noise
#     (2) Image-level augmentation
#     3. Cropping (bboxes must be adjusted)
#     4. Translation (bboxes must be adjusted)
#     5. Mirroring (bboxes must be adjusted)
#     6. Rotation (bboxes must be adjusted)
#     7. Occlusion (random erasing)
# Note:
#     random.seed(): the same seed always produces the same sequence of random numbers!

import time
import random
import cv2
import os
import copy
import math
import numpy as np
from PIL import Image
from skimage.util import random_noise
from skimage import exposure
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import ElementTree, Element
import torch


def show_pic(img, bboxes=None):
    '''
    Inputs:
        img:    image array
        bboxes: list of all bounding boxes in the image, format [[x_min, y_min, x_max, y_max], ...]
    '''
    cv2.imwrite('./1.jpg', img)  # round-trip through disk to force a uint8 BGR image for drawing
    img = cv2.imread('./1.jpg')
    for i in range(len(bboxes)):
        bbox = bboxes[i]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[2]
        y_max = bbox[3]
        cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), thickness=3)
    cv2.namedWindow('pic', 0)  # 0: resizable window
    cv2.moveWindow('pic', 0, 0)  # move the display window to this position on screen
    cv2.resizeWindow('pic', 1200, 800)  # display size of the visualisation
    cv2.imshow('pic', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()



# Images are handled as cv2-style numpy arrays
class DataAugmentForObjectDetection():
    def __init__(self, rotation_rate=0.5, max_rotation_angle=5,
                 crop_rate=0.5, shift_rate=0.5, change_light_rate=0.5,
                 add_noise_rate=0.5, flip_rate=0.5,
                 erase_rate=0.5, erase_length=10, erase_holes=1, erase_threshold=0.5):
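        '''
        The *_rate values are probability thresholds used by dataAugment().
        Note: max_rotation_angle, erase_length and erase_holes are currently
        overridden inside dataAugment(), which samples those values itself.
        '''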
        self.rotation_rate = rotation_rate
        self.max_rotation_angle = max_rotation_angle
        self.crop_rate = crop_rate
        self.shift_rate = shift_rate
        self.change_light_rate = change_light_rate
        self.add_noise_rate = add_noise_rate
        self.flip_rate = flip_rate
        self.erase_rate = erase_rate
        self.erase_length = erase_length
        self.erase_holes = erase_holes
        self.erase_threshold = erase_threshold

    # Add noise
    def sp_noise(self, image, prob):
        '''
        Add salt-and-pepper noise.
        prob: noise ratio
        '''
        output = np.zeros(image.shape, np.uint8)
        thres = 1 - prob
        for i in range(image.shape[0]):
            for j in range(image.shape[1]):
                rdn = random.random()
                if rdn < prob:
                    output[i][j] = 0
                elif rdn > thres:
                    output[i][j] = 255
                else:
                    output[i][j] = image[i][j]
        return output

    def gasuss_noise(self, image, mean=0, var=0.00001):
        '''
        Add Gaussian noise.
        mean: mean
        var: variance
        '''
        image = np.array(image / 255, dtype=float)
        noise = np.random.normal(mean, var ** 0.5, image.shape)
        out = image + noise
        if out.min() < 0:
            low_clip = -1.
        else:
            low_clip = 0.
        out = np.clip(out, low_clip, 1.0)
        out = np.uint8(out * 255)
        # cv.imshow("gasuss", out)
        return out

    def addNoise(self, img):
        '''
        Input:
            img: image array
        Output:
            the noisy image array; random_noise returns values in [0, 1], so multiply by 255
        '''
        # random.seed(int(time.time()))
        # return random_noise(img, mode='gaussian', seed=int(time.time()), clip=True)*255
        img = np.uint8(random_noise(img, mode='gaussian', clip=True, mean=0, var=0.00001) * 255)
        return img

    # Adjust brightness
    def changeLight(self, img):
        # random.seed(int(time.time()))
        flag = random.uniform(0.5, 1.5)  # gamma > 1 darkens the image, gamma < 1 brightens it
        img = np.abs(img)
        # if np.min(img) < 0:
        #     print(np.min(img))
        return exposure.adjust_gamma(img, flag)

    # Crop
    def crop_img_bboxes(self, img, bboxes):
        '''
        The cropped image must still contain every box.
        Inputs:
            img: image array
            bboxes: all bboxes in the image, a list whose elements are [x_min, y_min, x_max, y_max]; values must be numeric
        Outputs:
            crop_img: cropped image array
            crop_bboxes: list of bbox coordinates after cropping
        '''
        # ---------------------- Crop the image ------------------------------
        width = img.shape[1]
        height = img.shape[0]
        x_min = width
        x_max = 0
        y_min = height
        y_max = 0
        # The minimal box that still contains every object box
        for bbox in bboxes:
            x_min = min(x_min, bbox[0])
            y_min = min(y_min, bbox[1])
            x_max = max(x_max, bbox[2])
            y_max = max(y_max, bbox[3])

        d_to_left = x_min  # distance from the minimal enclosing box to the left edge
        d_to_right = width - x_max  # distance from the minimal enclosing box to the right edge
        d_to_top = y_min  # distance from the minimal enclosing box to the top edge
        d_to_bottom = height - y_max  # distance from the minimal enclosing box to the bottom edge

        # Randomly expand this minimal box
        crop_x_min = int(x_min - random.uniform(0, d_to_left))
        crop_y_min = int(y_min - random.uniform(0, d_to_top))
        crop_x_max = int(x_max + random.uniform(0, d_to_right))
        crop_y_max = int(y_max + random.uniform(0, d_to_bottom))

        # Randomly expand this minimal box (alternative: prevents the crop from becoming too small)
        # crop_x_min = int(x_min - random.uniform(d_to_left//2, d_to_left))
        # crop_y_min = int(y_min - random.uniform(d_to_top//2, d_to_top))
        # crop_x_max = int(x_max + random.uniform(d_to_right//2, d_to_right))
        # crop_y_max = int(y_max + random.uniform(d_to_bottom//2, d_to_bottom))

        # Make sure we do not go out of bounds
        crop_x_min = max(0, crop_x_min)
        crop_y_min = max(0, crop_y_min)
        crop_x_max = min(width, crop_x_max)
        crop_y_max = min(height, crop_y_max)

        crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]

        # ---------------------- Crop the bboxes -----------------------
        # Compute the bbox coordinates after cropping
        crop_bboxes = list()
        for bbox in bboxes:
            crop_bboxes.append([bbox[0] - crop_x_min, bbox[1] - crop_y_min, bbox[2] - crop_x_min, bbox[3] - crop_y_min, bbox[4]])

        return crop_img, crop_bboxes

    # Shift
    def shift_pic_bboxes(self, img, bboxes):
        '''
        The shifted image must still contain every box.
        Inputs:
            img: image array
            bboxes: all bboxes in the image, a list whose elements are [x_min, y_min, x_max, y_max]; values must be numeric
        Outputs:
            shift_img: shifted image array
            shift_bboxes: list of bbox coordinates after the shift
        '''
        # ---------------------- Shift the image ---------------------------
        width = img.shape[1]
        height = img.shape[0]
        x_min = width
        x_max = 0
        y_min = height
        y_max = 0
        for bbox in bboxes:
            x_min = min(x_min, bbox[0])
            y_min = min(y_min, bbox[1])
            x_max = max(x_max, bbox[2])
            y_max = max(y_max, bbox[3])
        # x_min / y_min are the smallest x / y over all object boxes,
        # x_max / y_max are the largest x / y over all object boxes
        d_to_left = x_min  # maximum distance the content can shift to the left
        d_to_right = width - x_max  # maximum distance the content can shift to the right
        d_to_top = y_min  # maximum distance the content can shift up
        d_to_bottom = height - y_max  # maximum distance the content can shift down

        x = random.uniform(-(d_to_left - 1) / 3, (d_to_right - 1) / 3)
        y = random.uniform(-(d_to_top - 1) / 3, (d_to_bottom - 1) / 3)
        # x is the horizontal shift in pixels: positive moves right, negative moves left;
        # y is the vertical shift in pixels: positive moves down, negative moves up
        M = np.float32([[1, 0, x], [0, 1, y]])
        shift_img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))

        # ---------------------- Shift the bboxes ----------------------
        shift_bboxes = list()
        for bbox in bboxes:
            shift_bboxes.append([bbox[0] + x, bbox[1] + y, bbox[2] + x, bbox[3] + y,bbox[4]])

        return shift_img, shift_bboxes

    # Mirror
    def filp_pic_bboxes(self, img, bboxes):
        '''
        Inputs:
            img: image array
            bboxes: all bboxes in the image, a list whose elements are [x_min, y_min, x_max, y_max]; values must be numeric
        Outputs:
            flip_img: flipped image array
            flip_bboxes: list of bbox coordinates after the flip
        '''
        # ---------------------- Flip the image ----------------------
        flip_img = copy.deepcopy(img)
        if random.random() < 0.5:  # 50% chance of a horizontal flip, 50% chance of a vertical flip
            horizon = True
        else:
            horizon = False
        height, width, _ = img.shape
        if horizon:  # horizontal flip
            flip_img = cv2.flip(flip_img, 1)  # 1: horizontal, 0: vertical, -1: both axes
        else:
            flip_img = cv2.flip(flip_img, 0)
        # ---------------------- Adjust the bboxes ----------------------
        flip_bboxes = list()
        for box in bboxes:
            x_min = box[0]
            y_min = box[1]
            x_max = box[2]
            y_max = box[3]
            if horizon:
                flip_bboxes.append([width - x_max, y_min, width - x_min, y_max,box[4]])
            else:
                flip_bboxes.append([x_min, height - y_max, x_max, height - y_min,box[4]])

        return flip_img, flip_bboxes

    # Rotate
    def rotate_img_bbox(self, img, bboxes, piangle=5, scale=1.):
        '''
        The rotated image should still contain every box, otherwise the original
        annotations would be destroyed; note that rotation may cut off corners of
        the image, which should be avoided.
        Inputs:
            img: image array, (h, w, c)
            bboxes: all bboxes in the image, a list whose elements are [x_min, y_min, x_max, y_max]; values must be numeric
            piangle: rotation angle in degrees
            scale: default 1 (the rotation matrix below is built with scale 1)
        Outputs:
            rotated_img: rotated image array
            rot_bboxes: list of bbox coordinates after rotation
        '''
        # ---------------------- Rotate the image -----------------------------------
        angle = -piangle * math.pi / 180.0
        rows, cols = img.shape[:2]
        a, b = cols / 2, rows / 2
        M = cv2.getRotationMatrix2D((a, b), piangle, 1)
        # img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
        rotated_img = cv2.warpAffine(img, M, (cols, rows))  # the rotated image keeps its original size

        # ---------------------- Correct the bbox coordinates ------------------------------
        # M above is the rotation matrix applied to the image.
        # Take the four corners of each original bbox, transform them into the rotated
        # coordinate frame, then tighten the axis-aligned box with the corner-midpoint trick.
        rot_bboxes = list()
        for bbox in bboxes:
            x1 = float(bbox[0])-1
            y1 = float(bbox[1])-1
            x2 = float(bbox[2])-1
            y2 = float(bbox[3])-1

            x3 = x1
            y3 = y2
            x4 = x2
            y4 = y1

            X1 = (x1 - a) * math.cos(angle) - (y1 - b) * math.sin(angle) + a
            Y1 = (x1 - a) * math.sin(angle) + (y1 - b) * math.cos(angle) + b

            X2 = (x2 - a) * math.cos(angle) - (y2 - b) * math.sin(angle) + a
            Y2 = (x2 - a) * math.sin(angle) + (y2 - b) * math.cos(angle) + b

            X3 = (x3 - a) * math.cos(angle) - (y3 - b) * math.sin(angle) + a
            Y3 = (x3 - a) * math.sin(angle) + (y3 - b) * math.cos(angle) + b

            X4 = (x4 - a) * math.cos(angle) - (y4 - b) * math.sin(angle) + a
            Y4 = (x4 - a) * math.sin(angle) + (y4 - b) * math.cos(angle) + b

            NEW_X1 = (X1 + X3) / 2
            NEW_X2 = (X2 + X3) / 2
            NEW_X3 = (X2 + X4) / 2
            NEW_X4 = (X1 + X4) / 2

            NEW_Y1 = (Y1 + Y3) / 2
            NEW_Y2 = (Y2 + Y3) / 2
            NEW_Y3 = (Y2 + Y4) / 2
            NEW_Y4 = (Y1 + Y4) / 2

            # X_MIN = min(X1, X2, X3, X4)
            # X_MAX = max(X1, X2, X3, X4)
            # Y_MIN = min(Y1, Y2, Y3, Y4)
            # Y_MAX = max(Y1, Y2, Y3, Y4)

            X_MIN = min(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
            X_MAX = max(NEW_X1, NEW_X2, NEW_X3, NEW_X4)
            Y_MIN = min(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)
            Y_MAX = max(NEW_Y1, NEW_Y2, NEW_Y3, NEW_Y4)

            # if X1>X2:
            #     X1,X2 = X2,X1
            #     Y1,Y2 = Y2,Y1
            #
            # # 求中心点
            # X_c = float(X1 + X2) / 2
            # Y_c = float(Y1 + Y2) / 2
            # H = Y2 - Y1
            # W = X2 - X1
            #
            # X_MIN = X_c - W/2
            # Y_MIN = Y_c - H/2
            # X_MAX = X_c + W/2
            # Y_MAX = Y_c + H/2


            # Append to the list
            rot_bboxes.append([X_MIN, Y_MIN, X_MAX, Y_MAX,bbox[4]])

        return rotated_img, rot_bboxes

    # Occlusion / random erasing
    def erase(self, img, bboxes, length=100, n_holes=1, threshold=0.5):
        '''
        Randomly mask out one or more square patches from an image.
        Args:
            img: a 3D numpy array, (h, w, c)
            bboxes: box coordinates
            n_holes (int): number of patches to cut out of each image
            length (int): side length (in pixels) of each square patch
        '''

        def cal_iou(boxA, boxB):
            '''
            Inputs: boxA, boxB
            Output: the overlap ratio. Note that it is normalised by boxB's area only
            (the fraction of boxB that is covered), not a symmetric IoU.
            '''
            # (x, y) coordinates of the intersection rectangle
            xA = max(boxA[0], boxB[0])
            yA = max(boxA[1], boxB[1])
            xB = min(boxA[2], boxB[2])
            yB = min(boxA[3], boxB[3])
            if xB <= xA or yB <= yA:
                return 0.0
            # area of the intersection rectangle
            interArea = (xB - xA + 1) * (yB - yA + 1)
            # areas of boxA and boxB
            boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
            boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
            # iou = interArea / float(boxAArea + boxBArea - interArea)
            iou = interArea / float(boxBArea)

            return iou

        # Get h, w (and channels)
        if img.ndim == 3:
            h, w, c = img.shape
        else:
            _, h, w, c = img.shape

        mask = np.ones((h, w, c), np.float32)

        for n in range(n_holes):
            overlap = True  # whether the cut region overlaps a box too much
            while overlap:
                y = np.random.randint(h)
                x = np.random.randint(w)
                # numpy.clip(a, a_min, a_max, out=None) limits array elements to [a_min, a_max]:
                # values above a_max are set to a_max, values below a_min are set to a_min
                x1 = np.clip(x - length // 2, 0, w)
                x2 = np.clip(x + length // 2, 0, w)
                y1 = np.clip(y - length // 2, 0, h)
                y2 = np.clip(y + length // 2, 0, h)
                overlap = False
                for box in bboxes:
                    if cal_iou([x1, y1, x2, y2], box) > threshold:
                        overlap = True
                        break
            mask[y1:y2, x1:x2, :] = 0.
        # mask = np.expand_dims(mask, axis=0)
        erase_img = (img * mask).astype(np.uint8)  # keep uint8 so the result can be colour-converted and saved

        return erase_img, bboxes

    def dataAugment(self, img, bboxes):
        '''
        Apply the augmentations.
        Inputs:
            img: image array
            bboxes: all bbox coordinates of the image
        Outputs:
            img: augmented image
            bboxes: boxes matching the augmented image
        '''
        change_num = random.sample(range(-6, 0), 1)[0]  # start the counter at a random negative value so several augmentations get applied
        print('------------------ starting data augmentation -------------------')
        while change_num < 1:  # by default at least one augmentation takes effect
            if random.random() < self.add_noise_rate:  # add noise
                print('adding noise')
                change_num += 1
                img = self.gasuss_noise(img)
                print(img.shape)

            if random.random() > self.change_light_rate:  # change brightness
                print('brightness')
                change_num += 1
                img = self.changeLight(img)
                print(img.shape)

            if random.random() < self.shift_rate:  # shift
                print('shift')
                change_num += 1
                img, bboxes = self.shift_pic_bboxes(img, bboxes)
                print(img.shape)

            if random.random() < self.flip_rate:  # mirror
                print('mirror')
                change_num += 1
                img, bboxes = self.filp_pic_bboxes(img, bboxes)
                print(img.shape)

            if random.random() > self.rotation_rate:  # rotate
                print('rotate')
                change_num += 1
                # angle = random.uniform(-self.max_rotation_angle, self.max_rotation_angle)
                angle = random.sample(range(-180, 180), 1)[0]
                # angle = -45
                print(angle)
                scale = random.uniform(0.7, 0.8)
                img, bboxes = self.rotate_img_bbox(img, bboxes, angle, scale)

            if random.random() < self.erase_rate:  # occlusion / random erasing
                print('erasing')
                change_num += 1
                erase_length = random.sample(range(0, 50), 1)[0]
                erase_holes = random.sample(range(0, 5), 1)[0]
                img, bboxes = self.erase(img, bboxes, length=erase_length, n_holes=erase_holes,
                                         threshold=self.erase_threshold)
                print(img.shape)

            print('\n')
        print('------------------ finished data augmentation -------------------')
        print(bboxes)
        return img, bboxes


def xyxy2xywh(x):
    # Normalisation factors; the code assumes 608x608 images
    dw = 1. / 608
    dh = 1. / 608
    # if w >= 1:
    #     w = 0.99
    # if h >= 1:
    #     h = 0.99
    # Convert nx4 boxes from [x1, y1, x2, y2] to normalised [x, y, w, h], where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    print(x.shape)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2 - 1) * dw  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2 - 1) * dh  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) * dw  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) * dh  # height
    return y
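
# Worked example for xyxy2xywh (illustrative numbers, using the 608x608 assumption above):
# the box [100, 50, 300, 250] maps to x_c = (100+300)/2 - 1 = 199 -> 199/608 ≈ 0.327,
# y_c = (50+250)/2 - 1 = 149 -> 149/608 ≈ 0.245, w = 200/608 ≈ 0.329, h = 200/608 ≈ 0.329.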

if __name__ == '__main__':
    import shutil
    from DOC import *

    need_aug_num = 10
    dataAug = DataAugmentForObjectDetection()
    source_pic_root_path = r'C:\Users\JKY\Desktop\***'  # your image folder
    source_xml_root_path = r'C:\Users\JKY\Desktop\***'  # your xml folder
    save_images = r'C:\Users\JKY\Desktop\DataAugmentation_ForObjectDetect-master\dataAugmentation-images'  # image save path
    save_labels = r'C:\Users\JKY\Desktop\DataAugmentation_ForObjectDetect-master\dataAugmentation-label'  # label save path
    for parent, _, files in os.walk(source_pic_root_path):
        for file in files:
            cnt = 0
            while cnt < need_aug_num:
                try:
                    pic_path = os.path.join(parent, file)
                    xml_path = os.path.join(source_xml_root_path, file[:-4] + '.xml')
                    if not os.path.exists(xml_path):
                        cnt += 1
                        continue
                    coords = parse_xml(xml_path)  # parse the xml to get box info, format [[x_min, y_min, x_max, y_max, name]]
                    # coords = [coord[:4] for coord in coords]
                    print(pic_path)
                    # img = cv2.imread(pic_path)
                    img = Image.open(pic_path)
                    img = np.array(img)
                    # show_pic(img, coords)  # original image

                    auged_img, auged_bboxes = dataAug.dataAugment(img, coords)
                    img = cv2.cvtColor(auged_img, cv2.COLOR_RGB2BGR)
                    cnt += 1
                    name = str(int(time.time() * 1e5))
                    cv2.imwrite(os.path.join(save_images,name+'.jpg'),img)
                    print('image save success')
                    txt_path = os.path.join(save_labels,name+'.txt')

                    auged_bboxes = np.float32(np.array(auged_bboxes))
                    xywh = xyxy2xywh(auged_bboxes).astype(str).tolist()
                    res_bboxes = []
                    for i in range(len(xywh)):
                        index = xywh[i][-1].index('.')
                        xywhs = [xywh[i][-1][:index]] + xywh[i][:4]
                        print(xywhs)
                        res_bboxes.append(xywhs)

                    res = []
                    print(res_bboxes)
                    for listi in res_bboxes:
                        stri = ' '.join(listi)
                        stri += '\n'
                        res.append(stri)
                    print(res_bboxes)
                    with open(txt_path, 'w+') as f:  # use a separate name so the loop variable `file` is not overwritten
                        f.writelines(res)
                    # show_pic(auged_img, auged_bboxes)  # image after augmentation
                except Exception:  # skip this sample on any error and continue
                    cnt += 1
                    continue
