【深度学习】带标注的数据增强及albumentations包下载

注:代码修改自另一位博主的文章,很抱歉找不到地址了。
做出以下修改:
1、把控制参数放在头部便于修改
2、原代码的增强前后图片和标注存放于一个文件夹,这里存放在新文件夹。
3、添加了一些注释
注意:
1、visualize部分未修改,可能会因为我改了变量名跑不通
2、缩放方法可能造成标注偏移,增强之后需要检查
albumentations包下载
如果下载一直超时(time out),可以改用清华镜像源:https://pypi.tuna.tsinghua.edu.cn/simple

pip install albumentations -i https://pypi.tuna.tsinghua.edu.cn/simple
import cv2
from matplotlib import pyplot as plt
import xml.etree.ElementTree as ET
import albumentations as A
import os
import time

# Control parameters
BOX_COLOR = (255, 0, 0)  # Red
TEXT_COLOR = (255, 255, 255)  # White
# Number of augmented copies generated per source image; with 62 original
# pictures the total output is 62 * GENERATED_PICS_SIZE.
GENERATED_PICS_SIZE = 20  # The augmentations live in the A.Compose pipeline that main applies
# Parent directory that contains the image and annotation folders
DIR = "D:\\AI\\data6"
# Name of the folder holding the original images; jpg is assumed by default,
# adjust the code yourself if the images are png.
IMAGES_FILE = "images"
# Name of the folder holding the original XML annotations
ANNOTATIONS_FILE = "annotations"
# Zero-based index of the <bndbox> child inside each <object> element; check
# the original XML annotations to determine it.
OBJ_NUM = 4  # i.e. the fifth child




def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
    """Draw one labelled bounding box on ``img`` and return the image.

    The box outline and the filled label background are both drawn in
    ``color``; the class name is written on that background in
    ``TEXT_COLOR``. The OpenCV drawing calls write directly into ``img``.

    :param img: image array accepted by the ``cv2`` drawing functions.
    :param bbox: ``(x_min, y_min, x_max, y_max)`` corner coordinates; each
        value is truncated to ``int`` before drawing.
    :param class_name: text written just above the box's top-left corner.
    :param color: colour of the outline and of the label background.
    :param thickness: outline thickness passed to ``cv2.rectangle``.
    :return: the same ``img`` object with the box and label drawn on it.
    """
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35

    # Convert once; OpenCV drawing functions require integer pixel positions.
    x_min, y_min, x_max, y_max = (int(value) for value in bbox)

    cv2.rectangle(img, (x_min, y_min), (x_max, y_max),
                  color=color, thickness=thickness)

    (text_width, text_height), _ = cv2.getTextSize(class_name, font_face,
                                                   font_scale, 1)
    # Filled strip above the box that serves as the label background. It uses
    # the caller-supplied colour so that label and outline always match.
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)),
                  (x_min + text_width, y_min), color, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=font_face,
        fontScale=font_scale,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img


def visualize(image, bboxes, category_ids, category_id_to_name):
    """Display a copy of ``image`` with every bounding box and label drawn.

    :param image: source image; it is copied, so the caller's array is not
        drawn on.
    :param bboxes: iterable of boxes, paired positionally with
        ``category_ids``.
    :param category_ids: iterable of category keys, one per box.
    :param category_id_to_name: mapping from a category key to the label text
        drawn next to its box.
    """
    canvas = image.copy()
    for box, cat_id in zip(bboxes, category_ids):
        label = category_id_to_name[cat_id]
        canvas = visualize_bbox(canvas, box, label)

    # Hide the axes so that only the annotated picture is shown.
    plt.axis('off')
    plt.imshow(canvas)
    plt.show()


def saveNewAnnotation(new_xml_path, new_image_path, xml_path, bboxes, cur_dir):
    """Write a copy of an XML annotation updated for an augmented image.

    The original annotation is parsed, its first three root children and the
    four coordinates of every ``<object>`` box are overwritten, and the result
    is written to ``new_xml_path``.

    :param new_xml_path: destination path of the new XML file.
    :param new_image_path: value stored in the second root child and appended
        to the stored image path (``main`` passes the bare file name here).
    :param xml_path: path of the original XML annotation (read as UTF-8).
    :param bboxes: one ``(xmin, ymin, xmax, ymax)`` sequence per ``<object>``
        element, in document order; values are rounded to integers.
    :param cur_dir: dataset root; the stored image path is
        ``<cur_dir>/images_aug/<new_image_path>``.
    :raises IndexError: if ``bboxes`` has fewer entries than the file has
        ``<object>`` elements, or an element lacks the expected children.
    """
    # The context manager closes the source file as soon as it is parsed.
    with open(xml_path, encoding='utf-8') as in_file:
        tree = ET.parse(in_file)
    root = tree.getroot()

    # Root children are addressed by position - presumably <folder>,
    # <filename> and <path> of a labelImg-style file; verify against the
    # actual annotations.
    root[0].text = "images"
    root[1].text = new_image_path
    root[2].text = os.path.join(cur_dir, 'images_aug', new_image_path)

    for idx, obj in enumerate(root.iter('object')):
        # OBJ_NUM is the position of <bndbox> inside <object>; a wrong value
        # raises IndexError here, so check it against the XML layout.
        bndbox = obj[OBJ_NUM]
        for coord_idx in range(4):
            bndbox[coord_idx].text = str(round(bboxes[idx][coord_idx]))
    tree.write(new_xml_path, 'UTF-8')


def getAnnotation(xml_path):
    """Read bounding boxes and class names from an XML annotation file.

    :param xml_path: path of the XML annotation file (read as UTF-8).
    :return: ``(bboxes, category_ids)``. ``bboxes`` is a list of
        ``[xmin, ymin, xmax, ymax]`` integer lists, one per ``<object>``
        element, with fractional values truncated. ``category_ids`` holds the
        ``<name>`` text of the same elements, in the same order. Both lists
        are empty when the file cannot be parsed.
    """
    # The context manager guarantees the file handle is released.
    with open(xml_path, encoding='utf-8') as in_file:
        try:
            tree = ET.parse(in_file)
        except (ET.ParseError, UnicodeError):
            # Malformed or wrongly encoded annotation: report "no boxes" so
            # that the caller can skip this image.
            return [], []
    root = tree.getroot()

    bboxes = []
    category_ids = []
    for obj in root.iter('object'):
        xmlbox = obj.find('bndbox')
        # int(float(...)) accepts coordinates stored as "12" as well as "12.0".
        bboxes.append([int(float(xmlbox.find(tag).text))
                       for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
        category_ids.append(obj.find('name').text)
    return bboxes, category_ids


def _build_transform():
    """Return the albumentations pipeline applied to every generated image."""
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),  # horizontal flip
            A.VerticalFlip(p=0.5),  # vertical flip
            # Slightly perturb brightness, contrast, saturation and hue.
            A.ColorJitter(brightness=0.05, contrast=0.05,
                          saturation=0.02, hue=0.02, p=1),
            A.Sharpen(p=1),  # sharpen to emphasise detail
        ],
        bbox_params=A.BboxParams(format='pascal_voc',
                                 label_fields=['category_ids']),
    )


def _augment_image(image_path, xml_path, stem, new_images_path, new_xmls_path,
                   transform, cur_dir):
    """Write GENERATED_PICS_SIZE augmented copies of one annotated image."""
    bboxes, category_ids = getAnnotation(xml_path=xml_path)
    if not bboxes:
        # Unparseable annotation or no objects: nothing to augment.
        return

    source = cv2.imread(image_path)
    if source is None:
        # cv2.imread signals failure by returning None instead of raising.
        print("Cannot read image, skipped : " + image_path)
        return
    source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)

    for index in range(1, GENERATED_PICS_SIZE + 1):
        # Name pattern: 1.jpg -> 1_001.jpg ... 1_020.jpg (and matching .xml).
        new_image_name = f"{stem}_{index:03d}.jpg"
        new_xml_name = f"{stem}_{index:03d}.xml"
        new_image_path = os.path.join(new_images_path, new_image_name)
        new_xml_path = os.path.join(new_xmls_path, new_xml_name)

        # Hand fresh copies to the pipeline so every pass starts from the
        # pristine source image and boxes.
        transformed = transform(image=source.copy(),
                                bboxes=[list(box) for box in bboxes],
                                category_ids=list(category_ids))
        augmented = cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR)

        # imencode + tofile is used instead of cv2.imwrite to save the image.
        cv2.imencode('.jpg', augmented)[1].tofile(new_image_path)
        # To inspect a result, call visualize() with the transformed image,
        # transformed['bboxes'], transformed['category_ids'] and a mapping of
        # each category id to its display name.
        saveNewAnnotation(new_xml_path, new_image_name, xml_path,
                          transformed['bboxes'], cur_dir)
        print(new_image_name)


def main(cur_dir):
    """Augment every annotated image found under ``cur_dir``.

    Each file in ``<cur_dir>/IMAGES_FILE`` is paired with the XML file of the
    same base name in ``<cur_dir>/ANNOTATIONS_FILE``. For every pair,
    ``GENERATED_PICS_SIZE`` augmented images are written to
    ``<cur_dir>/images_aug`` and their annotations to
    ``<cur_dir>/annotations_aug``; both folders are created when missing.
    Images without an annotation file are listed in
    ``<cur_dir>/no-annotations.txt``.

    :param cur_dir: dataset root directory.
    """
    images_path = os.path.join(cur_dir, IMAGES_FILE)
    xmls_path = os.path.join(cur_dir, ANNOTATIONS_FILE)
    new_images_path = os.path.join(cur_dir, "images_aug")
    new_xmls_path = os.path.join(cur_dir, "annotations_aug")
    # Saving fails if the target folders do not exist, so create them up front.
    os.makedirs(new_images_path, exist_ok=True)
    os.makedirs(new_xmls_path, exist_ok=True)

    # The pipeline does not depend on the image, so build it only once.
    transform = _build_transform()

    for image_name in os.listdir(images_path):
        image_path = os.path.join(images_path, image_name)
        # splitext keeps dots inside the base name ("a.b.jpg" -> "a.b").
        stem = os.path.splitext(image_name)[0]
        xml_path = os.path.join(xmls_path, stem + ".xml")
        if os.path.exists(xml_path):
            _augment_image(image_path, xml_path, stem, new_images_path,
                           new_xmls_path, transform, cur_dir)
        else:
            with open(os.path.join(cur_dir, "no-annotations.txt"), 'a') as f:
                print("No this annotations, name of image : " + image_name,
                      file=f)
        # One-second pause per source image, retained from the original
        # script; its purpose is not evident from the code.
        time.sleep(1)


if __name__ == '__main__':
    # When executed as a script, run the augmentation on the dataset at DIR.
    main(DIR)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

S_u_cheng

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值