import xml.dom.minidom
import cv2
from albumentations import(
BboxParams, RandomGamma, Compose, Blur, CenterCrop, HueSaturationValue,
MotionBlur, Cutout, RandomBrightness,RandomContrast
)
import os
import glob
def read_xml(path):
exp_xml = []
dom = xml.dom.minidom.parse(path) ## parse()获取DOM对象
root = dom.documentElement #获取根结点
img_name = root.getElementsByTagName("filename")[0] # 通过dom对象或根元素,再根据标签名获取元素节点,是个列表
#exp_xml.append(img_name.childNodes[0].data+".jpg")
exp_xml.append(img_name.childNodes[0].data)
#print("fileneme:%s"%img_name.childNodes[0].data)
label = root.getElementsByTagName("name")[0]
exp_xml.append(label.childNodes[0].data)
bonbox_xmin = root.getElementsByTagName("xmin")[0]
exp_xml.append(bonbox_xmin.childNodes[0].data)
bonbox_ymin = root.getElementsByTagName("ymin")[0]
exp_xml.append(bonbox_ymin.childNodes[0].data)
bonbox_xmax = root.getElementsByTagName("xmax")[0]
exp_xml.append(bonbox_xmax.childNodes[0].data)
bonbox_ymax = root.getElementsByTagName("ymax")[0]
exp_xml.append(bonbox_ymax.childNodes[0].data)
return exp_xml
def modify_xml(path,bbox,new_img_name,aug_file):
new_dom = xml.dom.minidom.parse(path)
new_root = new_dom.documentElement
new_img_xml_name = new_root.getElementsByTagName("filename")[0]
new_img_xml_name.childNodes[0].data = new_img_name
new_bonbox_xmin = new_root.getElementsByTagName("xmin")[0]
new_bonbox_xmin.childNodes[0].data = bbox[0]
new_bonbox_ymin = new_root.getElementsByTagName("ymin")[0]
new_bonbox_ymin.childNodes[0].data = bbox[1]
new_bonbox_xmax = new_root.getElementsByTagName("xmax")[0]
new_bonbox_xmax.childNodes[0].data = bbox[2]
new_bonbox_ymax = new_root.getElementsByTagName("ymax")[0]
new_bonbox_ymax.childNodes[0].data = bbox[3]
with open(os.path.join(aug_file,
aug_file+"\\{}.xml".format(new_img_name)), 'w') as fh: new_dom.writexml(fh)
def visualize_bbox(img, bbox, class_id, class_idx_to_name):
bbox = list(bbox)
x_min, y_min, x_max, y_max = bbox
x_min = int(x_min)
y_min = int(y_min)
x_max = int(x_max)
y_max = int(y_max)
image = cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255,0,0), 2)
class_name = class_idx_to_name[class_id]
((text_width, text_height), _) = cv2.getTextSize(class_name,cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
cv2.rectangle(image, (x_min, y_min - int(1.3 * text_height)), (x_min +text_width, y_min), (255,0,0), -1)
cv2.putText(image, class_name, (x_min, y_min - int(0.3 * text_height)),cv2.FONT_HERSHEY_SIMPLEX, 0.35,(255,255,255), lineType=cv2.LINE_AA)
return image
def get_aug(aug, min_area=0., min_visibility=0.):
return Compose(aug, bbox_params=BboxParams(format='pascal_voc',
min_area=min_area,
min_visibility=min_visibility,
label_fields=["category_id"]))
def augment():
aug = Compose([###需要修改的地方
# Blur(blur_limit = 7,p = 0.3),#模糊处理
# RandomGamma(gamma_limit=(80,120),p=0.5),#伽马变换
# #CenterCrop(height=400, width=400, p=0.2),#中心裁剪
# HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30,val_shift_limit=20, p=0.3),#HSV偏移
# MotionBlur(blur_limit=7, p=0.5),#动态模糊
# Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0,always_apply=False, p=0.2)#挖小洞
#RandomBrightness(limit=0.5, p=1),
#RandomContrast(limit=2.3, p=0.5)
])
return aug
def get_data(xml_date_path,img_path):
xml_date = read_xml(xml_date_path)
img = cv2.imread(img_path + '\\' + xml_date[0])
bbox =[[int(xml_date[2]),int(xml_date[3]),int(xml_date[4]),int(xml_date[5])]]
return bbox,img
def keep_aug_img(annotations):
aug_img = annotations['image'].copy()
for idx, bbox in enumerate(annotations['bboxes']):
bbox = list(bbox)
x_min, y_min, x_max, y_max = bbox
x_min = int(x_min)
y_min = int(y_min)
x_max = int(x_max)
y_max = int(y_max)
aug_bbox = [x_min,y_min,x_max,y_max]
return aug_img,aug_bbox
def visualize(annotations, category_id_to_name):
img = annotations['image'].copy()
for idx, bbox in enumerate(annotations['bboxes']):
img = visualize_bbox(img, bbox, annotations['category_id'][idx],category_id_to_name)
return img
def main():
xml_img_path = r"D:\DataSet\origin_test" # 存放 xml 和 img 数据地址###需要修改的地方
aug_file = r"D:\DataSet\aug_testlabel" # 增强 xml 存放地址 ###需要修改的地方
img_file=r"D:\DataSet\aug_test" #增强img存放地址 ###需要修改的地方
shample = 1 # 需要增强的次数 ###需要修改的地方
for n in range(shample):
num= 0
print(" 第 %d 次 "%n)
for xml_name in glob.glob(xml_img_path + "/*.xml"): # 循环
#print(" 第 %d 张图片 "%num)
bbox,img = get_data(xml_name,xml_img_path) # 获得 img 以及 xml 中bbox 坐标。 ↪
annotations = {'image': img, 'bboxes': bbox, 'category_id': [1]}
aug = augment()
augmented = aug(**annotations)
#category_id_to_name = {1:"juanyuanzi"}
#img,bbox = visualize(augmented, category_id_to_name)
#cv2.imshow("x",img)
#cv2.waitKey(0)
""" 可视化 """
aug_img,aug_bbox = keep_aug_img(augmented) # 增强后的图像及相对应的坐标
#a=xml_name[23:-4]
#cv2.imwrite(aug_file+"\\aug_img.jpg",aug_img)
cv2.imwrite(img_file + "\\hrg{}.jpg".format(xml_name[23:-4]), aug_img)#img 保存 ↪ ###需要修改的地方
#new_xml_path = os.path.split(aug_file + "\\aug_img.jpg")[1] # 获取增强xml 地址
new_xml_path =os.path.split(aug_file+"\\hrg{}.jpg".format(xml_name[23:-4]))[1] # 获取增强xml 地址 ###需要修改的地方
new_xml_name = new_xml_path.split(".")[0] # 获取 xml 名字
modify_xml(xml_name,aug_bbox,new_xml_name,aug_file) # 对xml 文件进行修改 ↪
num += 1
if __name__ == "__main__":
main()
(1)python glob.glob()函数:用于匹配文件路径,返回所有匹配的文件路径列表。
匹配符包括*、“?”和"[]",其中“*”表示匹配任意字符串,“?”匹配任意单个字符,[0-9]与[a-z]表示匹配0-9的单个数字与a-z的单个字符。
(2)Python format 格式化函数:Python2.6 开始,新增了一种格式化字符串的函数 str.format(),它增强了字符串格式化的功能。
基本语法是通过 {} 和 : 来代替以前的 % 。
(3)os.path.split():按照路径将文件名和路径分割开。
1.PATH指一个文件的全路径作为参数:
2.如果给出的是一个目录和文件名,则输出路径和文件名
3.如果给出的是一个目录名,则输出路径和为空文件名
(4)Python split() 通过指定分隔符对字符串进行切片,如果参数 num 有指定值,则仅分隔 num 个子字符串。
str.split(str="", num=string.count(str)).
参数:
str – 分隔符,默认为所有的空字符,包括空格、换行(\n)、制表符(\t)等。
num – 分割次数。
返回值:返回分割后的字符串列表。