python读取并修改XML文件

在VOC格式的SSD训练数据标签中:

1、添加图片宽高等信息;

2、坐标越界检查;

3、重设宽高尺寸;

4、获取标注框长宽分布;

5、删除特定类别;

6、删除没有标注物体的xml和对应图片;

7、根据xml生成可视化txt。

# coding: utf-8
import os
from xml.etree.ElementTree import ElementTree,Element
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random


def read_xml(in_path):
    '''''读取并解析xml文件
       in_path: xml路径
       return: ElementTree'''
    tree = ElementTree()
    tree.parse(in_path)
    return tree

def write_xml(tree, out_path):
    '''''将xml文件写出
       tree: xml树
       out_path: 写出路径'''
    tree.write(out_path, encoding="utf-8", xml_declaration=True)

def if_match(node, kv_map):
    '''''判断某个节点是否包含所有传入参数属性
       node: 节点
       kv_map: 属性及属性值组成的map'''
    for key in kv_map:
        if node.get(key) != kv_map.get(key):
            return False
    return True

# ----------------search -----------------
def find_nodes(tree, path):
    '''''查找某个路径匹配的所有节点
       tree: xml树
       path: 节点路径'''
    return tree.findall(path)

def get_node_by_keyvalue(nodelist, kv_map):
    '''''根据属性及属性值定位符合的节点,返回节点
       nodelist: 节点列表
       kv_map: 匹配属性及属性值map'''
    result_nodes = []
    for node in nodelist:
        if if_match(node, kv_map):
            result_nodes.append(node)
    return result_nodes

# ---------------change ----------------------
def change_node_properties(nodelist, kv_map, is_delete=False):
    '''修改/增加 /删除 节点的属性及属性值
       nodelist: 节点列表
       kv_map:属性及属性值map'''
    for node in nodelist:
        for key in kv_map:
            if is_delete:
                if key in node.attrib:
                    del node.attrib[key]
            else:
                node.set(key, kv_map.get(key))

def change_node_text(nodelist, text, is_add=False, is_delete=False):
    '''''改变/增加/删除一个节点的文本
       nodelist:节点列表
       text : 更新后的文本'''
    for node in nodelist:
        if is_add:
            node.text += text
        elif is_delete:
            node.text = ""
        else:
            node.text = text

def create_node(tag, property_map, content):
    '''新造一个节点
       tag:节点标签
       property_map:属性及属性值map
       content: 节点闭合标签里的文本内容
       return 新节点'''
    element = Element(tag, property_map)
    element.text = content
    return element

def add_child_node(nodelist, element):
    '''''给一个节点添加子节点
       nodelist: 节点列表
       element: 子节点'''
    for node in nodelist:
        node.append(element)


def del_node_by_tagkeyvalue(nodelist, tag, kv_map):
    '''''同过属性及属性值定位一个节点,并删除之
       nodelist: 父节点列表
       tag:子节点标签
       kv_map: 属性及属性值列表'''
    for parent_node in nodelist:
        children = parent_node.getchildren()
        for child in children:
            if child.tag == tag and if_match(child, kv_map):
                parent_node.remove(child)

if __name__ == "__main__":
    path = u"E:\\UnderWaterDetection\\720_480\\train\\box"  # xml文件所在的目录
    files = os.listdir(path)  # 得到文件夹下所有文件名称
    width_arr = []
    height_arr = []
    for xmlFile in files:  # 遍历文件夹
        #print(xmlFile)
        xmlPath = os.path.join(path, xmlFile)
        xmlPartName = xmlFile.split(".")[0]

        imageFile = xmlPartName + ".jpg"
        imagePath = os.path.join(u"E:\\UnderWaterDetection\\720_480\\train\\image\\", imageFile)
        fileHandle = open(imagePath, 'rb')
        img = Image.open(fileHandle)

        ################ 1. 读取xml文件  ##########
        tree = read_xml(xmlPath)
        root = tree.getroot()
        '''
        ################ 删除特定类别  ##########
        for obj in root.iter('object'):  # 获取object节点中的name子节点
            name = obj.find("name").text
            if name == 'waterweeds':
                print(xmlFile)
                root.remove(obj)
                write_xml(tree, xmlPath)
        '''

        ################ 删除标注物体为空的标注文件与对应图片  ##########
        obj = root.find("object")
        if obj == None:
            fileHandle.close()
            os.remove(imagePath)
            os.remove(xmlPath)
            print(xmlFile)

        '''
        ################ 根据xml生成可视化工具所需txt ###############
        txtPath = u"E:\\UnderWaterDetection\\车牌字符检测标注工具\\data\\720_480\\train\\image\\标注信息\\" + xmlPartName + ".txt"
        f = open(txtPath, 'w')
        size = root.find("size")
        orgin_width = size.find("width").text
        orgin_height = size.find("height").text
        f.write("<?xml version='1.0' encoding='GB2312'?>\n")
        f.write("<info>\n")
        f.write("	<src width=\"" + str(orgin_width) + "\" height = \"" + str(orgin_height) + "\" depth = \"" + str(3) + "\">" + imageFile + "</src>\n")
        for obj in root.iter('object'):  # 获取object节点中的name子节点
            name = obj.find("name")
            bndbox_node = obj.find("bndbox")
            xmin_node = bndbox_node.find("xmin")
            ymin_node = bndbox_node.find("ymin")
            xmax_node = bndbox_node.find("xmax")
            ymax_node = bndbox_node.find("ymax")
            f.write("	<object id=\"02720149-03FF-4866-" + str(random.randint(0, 1000000)) + "-" + str(random.randint(0, 1000000)) + "\">\n")
            f.write("		<rect lefttopx=\"" + xmin_node.text + "\" lefttopy = \"" + ymin_node.text + "\" rightbottomx = \"" + xmax_node.text + "\" rightbottomy = \"" + ymax_node.text + "\"></rect>\n")
            f.write("		<type>" + name.text + "</type>\n")
            f.write("		<descriinfo></descriinfo>\n")
            f.write("		<modifydate>2020-03-20 16:32:00</modifydate>\n")
            f.write("	</object>\n")
        f.write("</info>\n")
        f.close()
        '''
        '''
        ################ 训练集重设尺寸 ###############
        new_width = 300
        new_hieght = 300
        size = root.find("size")
        orgin_width = size.find("width")
        orgin_height = size.find("height")
        resize_img = img.resize((new_width, new_hieght), Image.ANTIALIAS)
        resize_img.save("E:\\UnderWaterDetection\\300_300\\train\\image\\" + imageFile)
        radio_width = float(new_width) / int(orgin_width.text)
        radio_height = float(new_hieght) / int(orgin_height.text)
        orgin_width.text = str(new_width)
        orgin_height.text = str(new_hieght)
        
        for obj in root.iter('object'):  # 获取object节点中的name子节点
            bndbox_node = obj.find("bndbox")
            xmin_node = bndbox_node.find("xmin")
            xmin_node.text = str(int(int(xmin_node.text) * radio_width))
            ymin_node = bndbox_node.find("ymin")
            ymin_node.text = str(int(int(ymin_node.text) * radio_height))
            xmax_node = bndbox_node.find("xmax")
            xmax_node.text = str(int(int(xmax_node.text) * radio_width))
            ymax_node = bndbox_node.find("ymax")
            ymax_node.text = str(int(int(ymax_node.text) * radio_height))
        '''
        '''
        ################ 坐标越界检查 ###############
        for obj in root.iter('object'):  # 获取object节点中的name子节点
            bndbox_node = obj.find("bndbox")
            xmin_node = bndbox_node.find("xmin")
            ymin_node = bndbox_node.find("ymin")
            xmax_node = bndbox_node.find("xmax")
            ymax_node = bndbox_node.find("ymax")
            if int(xmin_node.text) <= 0:
                xmin_node.text = "0"
            if int(ymin_node.text) <= 0:
                ymin_node.text = "0"
            if int(xmax_node.text) > img.size[0]:
                xmax_node.text = str(img.size[0])
            if int(ymax_node.text) > img.size[1]:
                ymax_node.text = str(img.size[1])
        '''
        '''
        #################  获取训练集标注物体长宽分布 ##############
        for obj in root.iter('object'):  # 获取object节点中的name子节点
            bndbox_node = obj.find("bndbox")
            xmin_node = bndbox_node.find("xmin")
            ymin_node = bndbox_node.find("ymin")
            xmax_node = bndbox_node.find("xmax")
            ymax_node = bndbox_node.find("ymax")
            width = int(xmax_node.text) -  int(xmin_node.text)
            height = int(ymax_node.text) - int(ymin_node.text)
            width_arr.append(width)
            height_arr.append(height)
        '''
        '''
        #################  添加图片信息 ##############
        a = create_node("size", {}, "") # 新建节点
        w = create_node("width", {}, str(img.size[0])) # 新建节点
        a.append(w)
        h = create_node("height", {}, str(img.size[1]))  # 新建节点
        a.append(h)
        d = create_node("depth", {}, "3")  # 新建节点
        a.append(d)
        root.append(a)
        '''
        '''
        #################  保存修改后的xml ##############
        savePath = os.path.join("E:\\UnderWaterDetection\\300_300\\train\\box\\", xmlFile)
        write_xml(tree, savePath)
        '''
    '''
    #################  生成散点图 ##############
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.set_title('训练集长宽分布')
    plt.xlabel('width')
    plt.ylabel('height')
    ax1.scatter(width_arr, height_arr, c='r', marker='.')
    plt.show()
    '''

 

参考文献:

  1. https://blog.csdn.net/wsp_1138886114/article/details/86576900
  2. http://www.py3study.com/Article/details/id/6864.html
  3. https://blog.csdn.net/u013951454/article/details/88602220
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值