基于Albumentations库的离线数据增强(支持在线和离线)

测试安装

# 首先进行图片增强小测试
# 首先进行图片增强小测试, 该测试只是选择一下增强的方式, 例如: 如果你的检测目标与颜色有关联, 
# 可能就不能选择改变颜色的增强方式, 如果采用镜像的增强方式, 左转路标可能就变成右转路标, 需要注意!!!
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt

# 读取原始图片
original_image = cv2.imread('/home/aistudio/work/0000.jpg')

# 像素级变换
transform_Pixel = A.Compose([
    # A.CLAHE(p=1),  # 直方图均衡
    # A.ChannelDropout(p=1),  # 随机丢弃通道
    # A.ChannelShuffle(p=1),  # 随机排列通道
    A.ColorJitter(p=1),  # 随机改变图像的亮度、对比度、饱和度、色调
])

# 空间级变换
transform_Spatial = A.Compose([
    # A.RandomCrop(width=256, height=256),
    A.HorizontalFlip(p=1),
    A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=1), # 与像素级变换结合使用
    # A.SafeRotate(limit=60, p=1),
    # A.Rotate(limit=45, p=1),
    # A.Affine(p=1),
    # A.GridDistortion(p=1),

])

# 进行增强变化
transformed = transform_Spatial(image=original_image)

# 获得增强后的图片
transformed_image = transformed["image"]
transformed_image = cv2.cvtColor(transformed_image, cv2.COLOR_BGR2RGB)
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

plt.subplot(1, 2, 1), plt.title("original image"), plt.axis('off')
plt.imshow(original_image) 
plt.subplot(1, 2, 2), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image)

plt.show()

COCO格式数据增强

coco格式如下:

COCO
 |-- annotations
  |-- train.json
  |-- val.json

 |-- train
  |-- 0000.jpg
  |-- 0001.jpg
  |-- …jpg

 |-- val
  |-- 0000.jpg
  |-- 0001.jpg
  |-- …jpg
# 定义增强类
class COCOAug(object):


    def __init__(self,
                 anno_path=None,
                 pre_image_path=None,
                 save_image_path=None,
                 anno_mode='train',
                 is_show=True,
                 start_filename_id=None,
                 start_anno_id=None,
                 ):
        """

        :param anno_path: json文件的路径
        :param pre_image_path: 需要增强的图片路径
        :param save_image_path: 保存的图片路径
        :param anno_mode: 有train,val两种, 同时也对应两种路径, 两种json文件[train.json, val.json]
        :param is_show: 是否实时展示: 每增强一张图片就把对应的标注框和标签画出并imshow
        :param start_filename_id: 新的图片起始名称. 同时也对应图片的id, 后续在此基础上依次+1,
                                  如果没有指定则按已有的图片长度继续+1
        :param start_anno_id: 新的注释id起始号, 后续在此基础上依次+1, 如果没有指定则按已有的注释个数长度继续+1
        """
        self.anno_path = anno_path
        self.aug_image_path = pre_image_path
        self.save_image_path = save_image_path
        self.anno_mode = anno_mode
        self.is_show = is_show
        self.start_filename_id = start_filename_id
        self.start_anno_id = start_anno_id

        # 数据增强选项
        self.aug = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
            A.GaussianBlur(p=0.7), # 高斯滤波
            A.GaussNoise(p=0.7), # 高斯模糊
            A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # 直方图均衡
            A.Equalize(p=0.5),  # 均衡图像直方图
            A.HorizontalFlip(p=1), 
            A.OneOf([
                # A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
                # A.ChannelShuffle(p=0.3),  # 随机排列通道
                # A.ColorJitter(p=0.3),  # 随机改变图像的亮度、对比度、饱和度、色调
                # A.ChannelDropout(p=0.3),  # 随机丢弃通道
            ], p=0.),
            # A.Downscale(p=0.1),  # 随机缩小和放大来降低图像质量
            A.Emboss(p=0.2),  # 压印输入图像并将结果与原始图像叠加
        ],
            # coco: [x_min, y_min, width, height]
            # min_area: 表示bbox占据的像素总个数, 当数据增强后, 若bbox小于这个值则从返回的bbox列表删除该bbox.
            # min_visibility: 值域为[0,1], 如果增强后的bbox面积和增强前的bbox面积比值小于该值, 则删除该bbox
            A.BboxParams(format='coco', min_area=0., min_visibility=0., label_fields=['category_id'])
        )

        # 打开json文件
        with open(os.path.join(self.anno_path, f"{self.anno_mode}.json"), 'r', encoding='utf-8') as load_f:
            self.load_dict = json.load(load_f)  # ['images', 'annotations', 'categories']

            self.labels = []  # 读取标签列表
            for anno in self.load_dict['categories']:
                self.labels.append(anno['name'])

            print("--------- * ---------")
            if self.start_filename_id is None:
                self.start_filename_id = len(self.load_dict['images'])
                print("the start_filename_id is not set, default: len(images)")
            if self.start_anno_id is None:
                self.start_anno_id = len(self.load_dict['annotations'])
                print("the start_anno_id is not set, default: len(annotations)")
            print("len(images)     : ", self.start_filename_id)
            print("len(annotations): ", self.start_anno_id)
            print("categories: ", self.load_dict['categories'])
            print("labels: ", self.labels)
            print("--------- * ---------")
    
    def image_aug(self, max_len=4):
        """
        json格式
        "images": [{"file_name": "013856.jpg", "height": 1080, "width": 1920, "id": 13856},...]
        "annotations": [{"image_id": 13856, "id": 0, "category_id": 2, "bbox": [541, 517, 79, 102],
                         "area": 8058, "iscrowd": 0, "segmentation": []}, ...]
        "categories": [{"id": 0, "name": "Motor Vehicle"}, ...]


        :param start_filename_id: 起始图片id号
        :param start_anno_id: 起始注释框id号
        :param max_len: 默认数据集不超过9999, 即: 0000~9999 如果更多可以设置为5 即00000~99999

        :return: None
        """
        # 保存原始数据
        aug_data = self.load_dict

        # 记录给定的开始序列
        cnt_filename = self.start_filename_id
        cnt_anno_id = self.start_anno_id

        # 对每一张图片遍历
        for index, item in enumerate(self.load_dict['images'][:]):
            image_name = item['file_name']
            image_suffix = image_name.split(".")[-1]  # 获取图片后缀 e.g. [.jpg .png]
            image_id = item['id']

            bboxes_list = []
            category_id_list = []
            # 对每一张图片找到所有的标注框, 并且bbox和label的id要对应上
            for anno in self.load_dict['annotations']:
                if anno['image_id'] == image_id:
                    bboxes_list.append(anno['bbox'])
                    category_id_list.append(anno['category_id'])
            # 读取图片
            image = cv2.imread(os.path.join(self.aug_image_path, image_name))
            h, w = image.shape[:2]
            # 生成需要增强的图片的anno字典
            # augmented {'image':, 'height':,'width:', 'bboxes':[(),()], 'category_id':[,,]}
            aug_anno = {'image': image, 'height': h, 'width': w, 'bboxes': bboxes_list, 'category_id': category_id_list}

            # 得到增强后的数据 {"image", "height", "width", "bboxes", "category_id"}
            augmented = self.aug(**aug_anno)
            # print(augmented)
            aug_image = augmented['image']
            aug_bboxes = augmented['bboxes']
            aug_category_id = augmented['category_id']
            height = augmented['height']
            width = augmented['width']

            # 对增强后的bbox取整
            for index, bbox in enumerate(aug_bboxes):
                x, y, w, h = bbox
                aug_bboxes[index] = [int(x + 0.5), int(y + 0.5), int(w + 0.5), int(h + 0.5)]

            # 是否进行实时展示图片, 用于检测是否有误
            if self.is_show:
                tl = 2
                # aug_image_copy = aug_image.copy()
                aug_image_copy = aug_image
                for bbox, category_id in zip(aug_bboxes, aug_category_id):
                    text = f"{self.labels[category_id]}"
                    t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tl)[0]
                    cv2.rectangle(aug_image_copy, (bbox[0], bbox[1] - 3),
                                  (bbox[0] + t_size[0], bbox[1] - t_size[1] - 3),
                                  (0, 0, 255), -1, cv2.LINE_AA)  # filled
                    cv2.putText(aug_image_copy, text, (bbox[0], bbox[1] - 2), 0, tl / 3, (255, 255, 255), tl,
                                cv2.LINE_AA)
                    aug_image_show = cv2.rectangle(aug_image_copy, (bbox[0], bbox[1]),
                                                   (bbox[0] + bbox[2], bbox[1] + bbox[3]),
                                                   (255, 255, 0), 2)

                # cv2.imshow('aug_image_show', aug_image_show)
                
                # 实时检测增强后的标注框是否有较大偏差, 符合要求按下's'健保存, 其他键跳过
                key = cv2.waitKey(0)
                # 按下s键保存增强,否则取消保存此次增强
                if key & 0xff == ord('s'):
                    pass
                else:
                    cv2.destroyWindow(f'aug_image_show')
                    continue
                cv2.destroyWindow(f'aug_image_show')


            # 获取新的图片名称 e.g.  cnt_filename=45   new_filename: 0045.image_suffix
            name = '0' * max_len  # e.g. '0'*4 = '0000'
            cnt_str = str(cnt_filename)
            length = len(cnt_str)
            new_filename = name[:-length] + cnt_str + f'.{image_suffix}'
            # 保存增强后的图片
            cv2.imwrite(os.path.join(self.save_image_path, new_filename), aug_image)
            # 添加增强后的图片
            dict_image = {
                "file_name": new_filename,
                "height": height,
                "width": width,
                "id": cnt_filename
            }
            aug_data['images'].append(dict_image)

            # print("augmented['bboxes']: ", augmented['bboxes'])
            for bbox, idx in zip(bboxes_list, category_id_list):
                dict_anno = {'image_id': cnt_filename,
                             'id': cnt_anno_id,
                             'category_id': idx,
                             'bbox': bbox,
                             'area': int(bbox[2] * bbox[3]),
                             'iscrowd': 0,
                             "segmentation": []
                             }
                aug_data['annotations'].append(dict_anno)

                # 每一个增加的anno_id+1
                cnt_anno_id += 1

            # 图片数+1
            cnt_filename += 1

        # 保存增强后的json文件
        with open(os.path.join(self.anno_path, f'aug_{self.anno_mode}.json'), 'w') as ft:
            json.dump(aug_data, ft)
# 对示例数据集进行增强, 运行成功后会在相应目录下保存
import os
import json
import cv2

# 图片路径
PRE_IMAGE_PATH = '/home/aistudio/work/TestImage/COCO/val'
SAVE_IMAGE_PATH = '/home/aistudio/work/TestImage/COCO/val'

# anno路径
ANNO_PATH = '/home/aistudio/work/TestImage/COCO/annotations'
mode = 'val'  # ['train', 'val']

aug = COCOAug(
        anno_path=ANNO_PATH,
        pre_image_path=PRE_IMAGE_PATH,
        save_image_path=SAVE_IMAGE_PATH,
        anno_mode=mode,
        is_show=False,
    )

aug.image_aug()

cv2.destroyAllWindows()

YOLO格式数据增强

yolo格式如下:

YOLO
 |-- images
  |-- 0000.jpg
  |-- 0001.jpg
  |-- …jpg

 |-- labels
  |-- 0000.txt
  |-- 0001.txt
  |-- …txt
# 定义类
class YOLOAug(object):


    def __init__(self,
                 pre_image_path=None,
                 pre_label_path=None,
                 aug_save_image_path=None,
                 aug_save_label_path=None,
                 labels=None,
                 is_show=True,
                 start_filename_id=None,
                 max_len=4):
        """
        
        :param pre_image_path: 
        :param pre_label_path: 
        :param aug_save_image_path: 
        :param aug_save_label_path: 
        :param labels: 标签列表, 需要根据自己的设定, 用于展示图片
        :param is_show: 
        :param start_filename_id: 
        :param max_len: 
        """
        self.pre_image_path = pre_image_path
        self.pre_label_path = pre_label_path
        self.aug_save_image_path = aug_save_image_path
        self.aug_save_label_path = aug_save_label_path
        self.labels = labels
        self.is_show = is_show
        self.start_filename_id = start_filename_id
        self.max_len = max_len
        # 数据增强选项
        self.aug = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
            A.GaussianBlur(p=0.7),
            A.GaussNoise(p=0.7),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # 直方图均衡
            A.Equalize(p=0.5),  # 均衡图像直方图
            A.OneOf([
                # A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
                # A.ChannelShuffle(p=0.3),  # 随机排列通道
                # A.ColorJitter(p=0.3),  # 随机改变图像的亮度、对比度、饱和度、色调
                # A.ChannelDropout(p=0.3),  # 随机丢弃通道
            ], p=0.),
            # A.Downscale(p=0.1),  # 随机缩小和放大来降低图像质量
            A.Emboss(p=0.2),  # 压印输入图像并将结果与原始图像叠加
        ],
            # yolo: [x_center, y_center, width, height]  # 经过归一化
            # min_area: 表示bbox占据的像素总个数, 当数据增强后, 若bbox小于这个值则从返回的bbox列表删除该bbox.
            # min_visibility: 值域为[0,1], 如果增强后的bbox面积和增强前的bbox面积比值小于该值, 则删除该bbox
            A.BboxParams(format='yolo', min_area=0., min_visibility=0., label_fields=['category_id'])
        )
        print("--------*--------")
        image_len = len(os.listdir(self.pre_image_path))
        print("the length of images: ", image_len)
        if self.start_filename_id is None:
            print("the start_filename id is not set, default: len(image)", image_len)
            self.start_filename_id = image_len

        print("--------*--------")


    def get_data(self, image_name):
        """
        获取图片和对应的label信息

        :param image_name: 图片文件名, e.g. 0000.jpg
        :return:
        """
        image = cv2.imread(os.path.join(self.pre_image_path, image_name))

        if len(image_name.split('.')[0]) == 0:
            return None

        with open(os.path.join(self.pre_label_path, image_name.split('.')[0] + '.txt'), 'r',encoding='utf-8') as f:
            label_txt = f.readlines()

        label_list = []
        cls_id_list = []
        for label in label_txt:
            label_info = label.strip().split(' ')
            cls_id_list.append(int(label_info[0]))
            label_list.append([float(x) for x in label_info[1:]])

        anno_info = {'image': image, 'bboxes': label_list, 'category_id': cls_id_list}
        return anno_info


    def aug_image(self):
        image_list = os.listdir(self.pre_image_path)

        file_name_id = self.start_filename_id
        for image_filename in image_list[:]:
            image_suffix = image_filename.split('.')[-1]
            if image_suffix not in ['jpg', 'png']:
                continue
            image_suffix = image_filename.split('.')[-1]

            aug_anno = self.get_data(image_filename)
            if aug_anno is None:
                continue

            # 获取增强后的信息
            augmented = self.aug(**aug_anno)  # {'image': , 'bboxes': , 'category_id': }
  
            aug_image = aug_info['image']
            aug_bboxes = aug_info['bboxes']
            aug_category_id = aug_info['category_id']

            name = '0' * self.max_len
            cnt_str = str(file_name_id)
            length = len(cnt_str)
            new_image_filename = name[:-length] + cnt_str + f'.{image_suffix}'
            new_label_filename = name[:-length] + cnt_str + '.txt'
            print(f"aug_image_{new_image_filename}: ")

            aug_image_copy = aug_image.copy()
            for cls_id, bbox in zip(aug_category_id, aug_bboxes):
                print(f" --- --- cls_id: ", cls_id)

                if self.is_show:
                    tl = 2
                    h, w = aug_image_copy.shape[:2]
                    x_center = int(bbox[0] * w)
                    y_center = int(bbox[1] * h)
                    width = int(bbox[2] * w)
                    height = int(bbox[3] * h)
                    xmin = int(x_center - width / 2)
                    ymin = int(y_center - height / 2)
                    xmax = int(x_center + width / 2)
                    ymax = int(y_center + height / 2)
                    text = f"{self.labels[cls_id]}"
                    t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tl)[0]
                    cv2.rectangle(aug_image_copy, (xmin, ymin - 3), (xmin + t_size[0], ymin - t_size[1] - 3), (0, 0, 255),
                                  -1, cv2.LINE_AA)  # filled
                    cv2.putText(aug_image_copy, text, (xmin, ymin - 2), 0, tl / 3, (255, 255, 255), tl, cv2.LINE_AA)
                    aug_image_show = cv2.rectangle(aug_image_copy, (xmin, ymin), (xmax, ymax), (255, 255, 0), 2)

            if self.is_show:
                cv2.imshow(f'aug_image_{new_image_filename}', aug_image_show)
                key = cv2.waitKey(0)
                # 按下s键保存增强,否则取消保存此次增强
                if key & 0xff == ord('s'):
                    pass
                else:
                    cv2.destroyWindow(f'aug_image_{new_image_filename}')
                    continue
                cv2.destroyWindow(f'aug_image_{new_image_filename}')
                
           # 保存增强后的信息
            cv2.imwrite(os.path.join(self.aug_save_image_path, new_image_filename), aug_image)
            with open(os.path.join(self.aug_save_label_path, new_label_filename), 'w', encoding='utf-8') as lf:
                for cls_id, bbox in zip(aug_category_id, aug_bboxes):
                    lf.write(str(cls_id) + ' ')
                    for i in bbox:
                        # 保存小数点后六位
                        lf.write(str(i)[:8] + ' ')
                    lf.write('\n')

            file_name_id += 1
# 对示例数据集进行增强, 运行成功后会在相应目录下保存 
import os
import json
import cv2
import numpy as np

# 原始图片和label路径
PRE_IMAGE_PATH = '/home/aistudio/work/TestImage/YOLO/images'
PRE_LABEL_PATH = '/home/aistudio/work/TestImage/YOLO/labels'

# 增强后的图片和label保存的路径
AUG_SAVE_IMAGE_PATH ='/home/aistudio/work/TestImage/YOLO/images'
AUG_SAVE_LABEL_PATH = '/home/aistudio/work/TestImage/YOLO/labels'

# 类别列表, 需要根据自己的修改
labels = ['side-walk', 'speed-limit', 'turn-left', 'slope', 'speed']

aug = YOLOAug(pre_image_path=PRE_IMAGE_PATH,
                pre_label_path=PRE_LABEL_PATH,
                aug_save_image_path=AUG_SAVE_IMAGE_PATH,
                aug_save_label_path=AUG_SAVE_LABEL_PATH,
                labels=labels,
                is_show=False)
aug.get_aug_data()

VOC格式数据增强

voc格式如下:

VOC
 |-- images
  |-- 0000.jpg
  |-- 0001.jpg
  |-- …jpg

 |-- labels
  |-- 0000.xml
  |-- 0001.xml
  |-- …xml
# 定义类
class VOCAug(object):


    def __init__(self,
                 pre_image_path=None,
                 pre_xml_path=None,
                 aug_image_save_path=None,
                 aug_xml_save_path=None,
                 start_aug_id=None,
                 labels=None,
                 max_len=4,
                 is_show=False):
        """
        
        :param pre_image_path: 
        :param pre_xml_path: 
        :param aug_image_save_path: 
        :param aug_xml_save_path: 
        :param start_aug_id: 
        :param labels: 标签列表, 展示增强后的图片用
        :param max_len: 
        :param is_show: 
        """
        self.pre_image_path = pre_image_path
        self.pre_xml_path = pre_xml_path
        self.aug_image_save_path = aug_image_save_path
        self.aug_xml_save_path = aug_xml_save_path
        self.start_aug_id = start_aug_id
        self.labels = labels
        self.max_len = max_len
        self.is_show = is_show

        print(self.labels)
        assert self.labels is not None, "labels is None!!!"

        # 数据增强选项
        # 数据增强选项
        self.aug = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
            A.GaussianBlur(p=0.7),
            A.GaussNoise(p=0.7),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # 直方图均衡
            A.Equalize(p=0.5),  # 均衡图像直方图
            A.OneOf([
                # A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
                # A.ChannelShuffle(p=0.3),  # 随机排列通道
                # A.ColorJitter(p=0.3),  # 随机改变图像的亮度、对比度、饱和度、色调
                # A.ChannelDropout(p=0.3),  # 随机丢弃通道
            ], p=0.),
            # A.Downscale(p=0.1),  # 随机缩小和放大来降低图像质量
            A.Emboss(p=0.2),  # 压印输入图像并将结果与原始图像叠加
        ],
            # voc: [xmin, ymin, xmax, ymax]  # 经过归一化
            # min_area: 表示bbox占据的像素总个数, 当数据增强后, 若bbox小于这个值则从返回的bbox列表删除该bbox.
            # min_visibility: 值域为[0,1], 如果增强后的bbox面积和增强前的bbox面积比值小于该值, 则删除该bbox
            A.BboxParams(format='pascal_voc', min_area=0., min_visibility=0., label_fields=['category_id'])
        )
        print('--------------*--------------')
        print("labels: ", self.labels)
        if self.start_aug_id is None:
            self.start_aug_id = len(os.listdir(self.pre_xml_path))
            print("the start_aug_id is not set, default: len(images)", self.start_aug_id)
        print('--------------*--------------')


    def get_xml_data(self, xml_filename):
        with open(os.path.join(self.pre_xml_path, xml_filename), 'r') as f:
            tree = ET.parse(f)
            root = tree.getroot()
            image_name = tree.find('filename').text
            size = root.find('size')
            w = int(size.find('width').text)
            h = int(size.find('height').text)
            bboxes = []
            cls_id_list = []
            for obj in root.iter('object'):
                # difficult = obj.find('difficult').text
                difficult = obj.find('difficult').text
                cls_name = obj.find('name').text  # label
                if cls_name not in LABELS or int(difficult) == 1:
                    continue
                xml_box = obj.find('bndbox')

                xmin = int(xml_box.find('xmin').text)
                ymin = int(xml_box.find('ymin').text)
                xmax = int(xml_box.find('xmax').text)
                ymax = int(xml_box.find('ymax').text)

                # 标注越界修正
                if xmax > w:
                    xmax = w
                if ymax > h:
                    ymax = h
                bbox = [xmin, ymin, xmax, ymax]
                bboxes.append(bbox)
                cls_id_list.append(self.labels.index(cls_name))

            # 读取图片
            image = cv2.imread(os.path.join(self.pre_image_path, image_name))

        return bboxes, cls_id_list, image, image_name


    def aug_image(self):
        xml_list = os.listdir(self.pre_xml_path)

        cnt = self.start_aug_id
        for xml in xml_list:
            file_suffix = xml.split('.')[-1]
            if file_suffix not in ['xml']:
                continue
                
            bboxes, cls_id_list, image, image_name = self.get_xml_data(xml)

            anno_dict = {'image': image, 'bboxes': bboxes, 'category_id': cls_id_list}
            # 获得增强后的数据 {"image", "bboxes", "category_id"}
            augmented = self.aug(**anno_dict)

            # 保存增强后的数据
            flag = self.save_aug_data(augmented, image_name, cnt)

            if flag:
                cnt += 1
            else:
                continue


    def save_aug_data(self, augmented, image_name, cnt):
        aug_image = augmented['image']
        aug_bboxes = augmented['bboxes']
        aug_category_id = augmented['category_id']
        # print(aug_bboxes)
        # print(aug_category_id)

        name = '0' * self.max_len
        # 获取图片的后缀名
        image_suffix = image_name.split(".")[-1]

        # 未增强对应的xml文件名
        pre_xml_name = image_name.replace(image_suffix, 'xml')

        # 获取新的增强图像的文件名
        cnt_str = str(cnt)
        length = len(cnt_str)
        new_image_name = name[:-length] + cnt_str + "." + image_suffix

        # 获取新的增强xml文本的文件名
        new_xml_name = new_image_name.replace(image_suffix, 'xml')

        # 获取增强后的图片新的宽和高
        new_image_height, new_image_width = aug_image.shape[:2]

        # 深拷贝图片
        aug_image_copy = aug_image.copy()

        # 在对应的原始xml上进行修改, 获得增强后的xml文本
        with open(os.path.join(self.pre_xml_path, pre_xml_name), 'r') as pre_xml:
            aug_tree = ET.parse(pre_xml)

        # 修改image_filename值
        root = aug_tree.getroot()
        aug_tree.find('filename').text = new_image_name

        # 修改变换后的图片大小
        size = root.find('size')
        size.find('width').text = str(new_image_width)
        size.find('height').text = str(new_image_height)

        # 修改每一个标注框
        for index, obj in enumerate(root.iter('object')):
            obj.find('name').text = self.labels[aug_category_id[index]]
            xmin, ymin, xmax, ymax = aug_bboxes[index]
            xml_box = obj.find('bndbox')
            xml_box.find('xmin').text = str(int(xmin))
            xml_box.find('ymin').text = str(int(ymin))
            xml_box.find('xmax').text = str(int(xmax))
            xml_box.find('ymax').text = str(int(ymax))
            if self.is_show:
                tl = 2
                text = f"{LABELS[aug_category_id[index]]}"
                t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tl)[0]
                cv2.rectangle(aug_image, (int(xmin), int(ymin) - 3),
                              (int(xmin) + t_size[0], int(ymin) - t_size[1] - 3),
                              (0, 0, 255), -1, cv2.LINE_AA)  # filled
                cv2.putText(aug_image, text, (int(xmin), int(ymin) - 2), 0, tl / 3, (255, 255, 255), tl,
                            cv2.LINE_AA)
                cv2.rectangle(aug_image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 0), 2)

        if self.is_show:
            cv2.imshow('aug_image_show', aug_image_copy)
            # 按下s键保存增强,否则取消保存此次增强
            key = cv2.waitKey(0)
            if key & 0xff == ord('s'):
                pass
            else:
                return False
        # 保存增强后的图片
        cv2.imwrite(os.path.join(self.aug_image_save_path, new_image_name), aug_image)
        # 保存增强后的xml文件
        tree = ET.ElementTree(root)
        tree.write(os.path.join(self.aug_xml_save_path, new_xml_name))
        
        return True
import os
import cv2

import albumentations as A
import xml.etree.ElementTree as ET

# 原始的xml路径和图片路径
PRE_IMAGE_PATH = '/home/aistudio/work/TestImage/VOC/images'
PRE_XML_PATH = '/home/aistudio/work/TestImage/VOC/labels'

# 增强后保存的xml路径和图片路径
AUG_SAVE_IMAGE_PATH ='/home/aistudio/work/TestImage/VOC/images'
AUG_SAVE_XML_PATH = '/home/aistudio/work/TestImage/VOC/labels'

# 标签列表
LABELS = ['zu', 'pai', 'lan']

aug = VOCAug(
    pre_image_path=PRE_IMAGE_PATH,
    pre_xml_path=PRE_XML_PATH,
    aug_image_save_path=AUG_SAVE_IMAGE_PATH,
    aug_xml_save_path=AUG_SAVE_XML_PATH,
    start_aug_id=None,
    labels=LABELS,
    is_show=False,
)

aug.aug_image()

# cv2.destroyAllWindows()
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

落难Coder

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值