目标检测数据集使用Albumentation数据库进行数据增强

import xml.dom.minidom
import cv2
from albumentations import(
        BboxParams, RandomGamma, Compose, Blur, CenterCrop, HueSaturationValue,
        MotionBlur, Cutout, RandomBrightness,RandomContrast
)
import os
import glob
def read_xml(path):
    exp_xml = []
    dom = xml.dom.minidom.parse(path) ## parse()获取DOM对象
    root = dom.documentElement #获取根结点
    img_name = root.getElementsByTagName("filename")[0] # 通过dom对象或根元素,再根据标签名获取元素节点,是个列表
    #exp_xml.append(img_name.childNodes[0].data+".jpg")
    exp_xml.append(img_name.childNodes[0].data)
    #print("fileneme:%s"%img_name.childNodes[0].data)
    label = root.getElementsByTagName("name")[0]
    exp_xml.append(label.childNodes[0].data)
    bonbox_xmin = root.getElementsByTagName("xmin")[0]
    exp_xml.append(bonbox_xmin.childNodes[0].data)
    bonbox_ymin = root.getElementsByTagName("ymin")[0]
    exp_xml.append(bonbox_ymin.childNodes[0].data)
    bonbox_xmax = root.getElementsByTagName("xmax")[0]
    exp_xml.append(bonbox_xmax.childNodes[0].data)
    bonbox_ymax = root.getElementsByTagName("ymax")[0]
    exp_xml.append(bonbox_ymax.childNodes[0].data)
    return exp_xml
def modify_xml(path,bbox,new_img_name,aug_file):
    new_dom = xml.dom.minidom.parse(path)
    new_root = new_dom.documentElement
    new_img_xml_name = new_root.getElementsByTagName("filename")[0]
    new_img_xml_name.childNodes[0].data = new_img_name
    new_bonbox_xmin = new_root.getElementsByTagName("xmin")[0]
    new_bonbox_xmin.childNodes[0].data = bbox[0]
    new_bonbox_ymin = new_root.getElementsByTagName("ymin")[0]
    new_bonbox_ymin.childNodes[0].data = bbox[1]
    new_bonbox_xmax = new_root.getElementsByTagName("xmax")[0]
    new_bonbox_xmax.childNodes[0].data = bbox[2]
    new_bonbox_ymax = new_root.getElementsByTagName("ymax")[0]
    new_bonbox_ymax.childNodes[0].data = bbox[3]
    with open(os.path.join(aug_file,
        aug_file+"\\{}.xml".format(new_img_name)), 'w') as fh: new_dom.writexml(fh)
def visualize_bbox(img, bbox, class_id, class_idx_to_name):
    bbox = list(bbox)
    x_min, y_min, x_max, y_max = bbox
    x_min = int(x_min)
    y_min = int(y_min)
    x_max = int(x_max)
    y_max = int(y_max)
    image = cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255,0,0), 2)
    class_name = class_idx_to_name[class_id]
    ((text_width, text_height), _) = cv2.getTextSize(class_name,cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
    cv2.rectangle(image, (x_min, y_min - int(1.3 * text_height)), (x_min +text_width, y_min), (255,0,0), -1)
    cv2.putText(image, class_name, (x_min, y_min - int(0.3 * text_height)),cv2.FONT_HERSHEY_SIMPLEX, 0.35,(255,255,255), lineType=cv2.LINE_AA)
    return image
def get_aug(aug, min_area=0., min_visibility=0.):
    return Compose(aug, bbox_params=BboxParams(format='pascal_voc',
    min_area=min_area,
    min_visibility=min_visibility,
    label_fields=["category_id"]))
def augment():
    aug = Compose([###需要修改的地方
        # Blur(blur_limit = 7,p = 0.3),#模糊处理
        # RandomGamma(gamma_limit=(80,120),p=0.5),#伽马变换
        # #CenterCrop(height=400, width=400, p=0.2),#中心裁剪
        # HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30,val_shift_limit=20, p=0.3),#HSV偏移
        # MotionBlur(blur_limit=7, p=0.5),#动态模糊
        # Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0,always_apply=False, p=0.2)#挖小洞
          #RandomBrightness(limit=0.5, p=1),
          #RandomContrast(limit=2.3, p=0.5)
                    ])
    return aug
def get_data(xml_date_path,img_path):
    xml_date = read_xml(xml_date_path)
    img = cv2.imread(img_path + '\\' + xml_date[0])
    bbox =[[int(xml_date[2]),int(xml_date[3]),int(xml_date[4]),int(xml_date[5])]]
    return bbox,img
def keep_aug_img(annotations):
    aug_img = annotations['image'].copy()
    for idx, bbox in enumerate(annotations['bboxes']):
        bbox = list(bbox)
        x_min, y_min, x_max, y_max = bbox
        x_min = int(x_min)
        y_min = int(y_min)
        x_max = int(x_max)
        y_max = int(y_max)
    aug_bbox = [x_min,y_min,x_max,y_max]
    return aug_img,aug_bbox
def visualize(annotations, category_id_to_name):
    img = annotations['image'].copy()
    for idx, bbox in enumerate(annotations['bboxes']):
        img = visualize_bbox(img, bbox, annotations['category_id'][idx],category_id_to_name)
    return img
def main():
    xml_img_path = r"D:\DataSet\origin_test" # 存放 xml 和 img 数据地址###需要修改的地方
    aug_file = r"D:\DataSet\aug_testlabel" # 增强 xml 存放地址 ###需要修改的地方
    img_file=r"D:\DataSet\aug_test" #增强img存放地址 ###需要修改的地方
    shample = 1  # 需要增强的次数  ###需要修改的地方
    for n in range(shample):
        num= 0
        print(" 第 %d 次 "%n)
        for xml_name in glob.glob(xml_img_path + "/*.xml"): # 循环
            #print(" 第 %d 张图片 "%num)
            bbox,img = get_data(xml_name,xml_img_path) # 获得 img 以及 xml 中bbox 坐标。 ↪
            annotations = {'image': img, 'bboxes': bbox, 'category_id': [1]}
            aug = augment()
            augmented = aug(**annotations)
            #category_id_to_name = {1:"juanyuanzi"}
            #img,bbox = visualize(augmented, category_id_to_name)
            #cv2.imshow("x",img)
            #cv2.waitKey(0)
            """ 可视化 """
            aug_img,aug_bbox = keep_aug_img(augmented) # 增强后的图像及相对应的坐标
            #a=xml_name[23:-4]
            #cv2.imwrite(aug_file+"\\aug_img.jpg",aug_img)
            cv2.imwrite(img_file + "\\hrg{}.jpg".format(xml_name[23:-4]), aug_img)#img 保存 ↪ ###需要修改的地方
            #new_xml_path = os.path.split(aug_file + "\\aug_img.jpg")[1]  # 获取增强xml 地址
            new_xml_path =os.path.split(aug_file+"\\hrg{}.jpg".format(xml_name[23:-4]))[1] # 获取增强xml 地址 ###需要修改的地方
            new_xml_name = new_xml_path.split(".")[0] # 获取 xml 名字
            modify_xml(xml_name,aug_bbox,new_xml_name,aug_file) # 对xml 文件进行修改 ↪
            num += 1
if __name__ == "__main__":
    main()

(1)python glob.glob()函数:用于匹配文件路径,返回所有匹配的文件路径列表。
匹配符包括*、“?”和"[]",其中“*”表示匹配任意字符串,“?”匹配任意单个字符,[0-9]与[a-z]表示匹配0-9的单个数字与a-z的单个字符。
(2)Python format 格式化函数:Python2.6 开始,新增了一种格式化字符串的函数 str.format(),它增强了字符串格式化的功能。
基本语法是通过 {} 和 : 来代替以前的 % 。
(3)os.path.split():按照路径将文件名和路径分割开。
1.PATH指一个文件的全路径作为参数:
2.如果给出的是一个目录和文件名,则输出路径和文件名
3.如果给出的是一个目录名,则输出路径和为空文件名
(4)Python split() 通过指定分隔符对字符串进行切片,如果参数 num 有指定值,则仅分隔 num 个子字符串。
str.split(str="", num=string.count(str)).
参数:
str – 分隔符,默认为所有的空字符,包括空格、换行(\n)、制表符(\t)等。
num – 分割次数。
返回值:返回分割后的字符串列表。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值