批量分割mask转xml检测框

代码

import sys
import os
import cv2
import yaml
import copy
import xmltodict
import numpy as np
import skimage.io as io
import xml.etree.ElementTree as ET
from xml.dom.minidom import parse
from glob import glob
sys.path.append('..')


def json_to_xml(json_str):
    """Serialize a nested mapping to an XML string.

    Despite the parameter name, ``json_str`` is handed straight to
    ``xmltodict.unparse``, which expects a dict-like object with a single
    root key. Pretty-printing is enabled so the result is indented.

    Returns:
        The XML document as a string.
    """
    return xmltodict.unparse(json_str, pretty=True)


def img2xml(folder: str, filename: str, path: str, width: int, height: int, type: str, name: str,
            pose: str, truncated: int, difficult: int, xmin: int, ymin: int, xmax: int, ymax: int):
    """Build a VOC-style annotation XML string that contains a single object.

    Args:
        folder: Value written to ``<folder>``.
        filename: Value written to ``<filename>``.
        path: Value written to ``<path>``.
        width: Image width written to ``<size><width>``.
        height: Image height written to ``<size><height>``.
        type: Value written to ``<object><type>``.
        name: Class name written to ``<object><name>``.
        pose: Value written to ``<object><pose>``.
        truncated: Value written to ``<object><truncated>``.
        difficult: Value written to ``<object><difficult>``.
        xmin, ymin, xmax, ymax: Bounding-box corners written to ``<bndbox>``.

    Returns:
        The XML document produced by ``json_to_xml``.
    """
    annotation = {
        'folder': folder,
        'filename': filename,
        # Previously this stored ``filename`` and silently ignored ``path``.
        'path': path,
        'source': {'database': "Unknown"},
        # Depth is fixed at 3; the function has no parameter for it.
        'size': {'width': width, 'height': height, 'depth': 3},
        'segmented': 0,
        'object': {
            'type': type,
            'name': name,
            'pose': pose,
            'truncated': truncated,
            'difficult': difficult,
            'bndbox': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax},
        },
    }
    return json_to_xml({'annotation': annotation})


def _append_text_element(dom_tree, parent, tag, text):
    """Create ``<tag>text</tag>`` in *dom_tree* and append it to *parent*."""
    node = dom_tree.createElement(tag)
    # minidom text nodes only accept strings, so coerce numbers here.
    node.appendChild(dom_tree.createTextNode(str(text)))
    parent.appendChild(node)
    return node


def writeXML(domTree_path, aimPath, type: str, name: str, pose: str, bndbox: dict):
    """Append one ``<object>`` element to an existing annotation file.

    The document at ``domTree_path`` is parsed, a new ``<object>`` (with
    ``type``, ``name``, ``pose``, ``truncated``, ``difficult`` and ``bndbox``
    children) is appended to its root element, and the result is written to
    ``aimPath``. ``truncated`` and ``difficult`` are always written as ``0``.

    If ``domTree_path`` does not exist the function does nothing.

    Args:
        domTree_path: Path of the XML file to read.
        aimPath: Path of the XML file to write (may equal ``domTree_path``).
        type: Text for the ``<type>`` child.
        name: Text for the ``<name>`` child.
        pose: Text for the ``<pose>`` child.
        bndbox: Mapping with keys ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
    """
    if not os.path.exists(domTree_path):
        return

    domTree = parse(domTree_path)
    object_node = domTree.createElement("object")

    for tag, text in (("type", type), ("name", name), ("pose", pose),
                      ("truncated", 0), ("difficult", 0)):
        _append_text_element(domTree, object_node, tag, text)

    bndbox_node = domTree.createElement("bndbox")
    for key in ('xmin', 'ymin', 'xmax', 'ymax'):
        _append_text_element(domTree, bndbox_node, key, bndbox[key])
    object_node.appendChild(bndbox_node)

    domTree.documentElement.appendChild(object_node)

    # The XML declaration announces utf-8, so the file itself must be written
    # as utf-8 rather than in the platform's default encoding.
    with open(aimPath, 'w', encoding='utf-8') as f:
        domTree.writexml(f, addindent='', encoding='utf-8')


def prettyXml(element,
              indent,
              newline,
              level=0):
    """Indent an ElementTree element in place and return it.

    Whitespace is injected into ``text`` and ``tail`` so that every child
    starts on its own line, indented one step deeper than its parent.
    Elements without children keep their text untouched.

    Args:
        element: The ``Element`` to format (modified in place).
        indent: String used for one indentation step, e.g. ``'\\t'``.
        newline: String used as the line separator, e.g. ``'\\n'``.
        level: Nesting depth of ``element``; callers normally leave it at 0.

    Returns:
        The same ``element`` object.
    """
    # Use an explicit child list instead of ``if element:`` because truth
    # testing of Element objects is deprecated.
    children = list(element)

    if children:
        child_pad = newline + indent * (level + 1)
        if element.text is None or element.text.isspace():
            element.text = child_pad
        else:
            # Keep existing text, but place it on its own indented line.
            element.text = child_pad + element.text.strip() + child_pad

    last = len(children) - 1
    for position, child in enumerate(children):
        # A sibling follows: stay at the children's depth. After the last
        # child comes the parent's closing tag, which sits one step shallower.
        depth_after = level + 1 if position < last else level
        child.tail = newline + indent * depth_after
        prettyXml(child, indent, newline, level=level + 1)

    return element


def img2xml_multiobj(tmpPath: str, aimPath: str, folder: str, filename: str,
                     path: str, width: int, height: int, objs: list):
    """Write an annotation XML file that contains every object in ``objs``.

    The first object is rendered with ``img2xml`` into ``tmpPath``; each
    further object is appended to that same file with ``writeXML``. The
    accumulated document is then re-indented with ``prettyXml`` and written
    to ``aimPath``. When ``tmpPath`` and ``aimPath`` are the same path the
    file is simply rewritten in place.

    Nothing is written when ``objs`` is empty.

    Args:
        tmpPath: Working file in which the document is accumulated.
        aimPath: Destination of the final, indented document.
        folder, filename, path, width, height: Forwarded to ``img2xml``.
        objs: List of dicts with keys ``type``, ``name``, ``pose``,
            ``truncated``, ``difficult`` and ``bndbox`` (itself a dict with
            ``xmin``, ``ymin``, ``xmax``, ``ymax``).
    """
    if not objs:
        return

    first, *rest = objs
    box = first['bndbox']

    # The generated document declares utf-8, so write it as utf-8.
    with open(tmpPath, 'w', encoding='utf-8') as f:
        f.write(img2xml(folder, filename, path, width, height,
                        first['type'], first['name'], first['pose'],
                        first['truncated'], first['difficult'],
                        box['xmin'], box['ymin'], box['xmax'], box['ymax']))

    for extra in rest:
        bn = extra['bndbox']
        bndbox = {'xmin': bn['xmin'], 'ymin': bn['ymin'], 'xmax': bn['xmax'], 'ymax': bn['ymax']}
        # Read and write the same file so objects accumulate. Writing to a
        # different path here would restart from the one-object document on
        # every iteration and keep only the last extra object.
        writeXML(tmpPath, tmpPath, extra['type'], extra['name'], extra['pose'], bndbox)

    root = prettyXml(ET.parse(tmpPath).getroot(), '\t', '\n')
    ET.ElementTree(root).write(aimPath)


def getMultiObjs_voc_withYaml(oriImgPath, labelPath, savePath, yamlPath=''):
    """Convert one label mask into an XML file of per-region bounding boxes.

    For every ``class name -> pixel value`` entry under ``label_names`` in the
    YAML file, the pixels equal to that value are split into connected
    components, and the bounding rectangle of each component that passes the
    size filter becomes one object. The XML is written to
    ``<savePath>/<image base name>.xml`` via ``img2xml_multiobj``; no file is
    produced when no object passes the filter.

    Args:
        oriImgPath: Path of the original image; only used for the file name
            and the ``<path>`` entry, the image itself is not read.
        labelPath: Path of the label mask, or an already loaded array.
            Assumed to be a single-channel integer image - TODO confirm.
        savePath: Directory that receives the XML file.
        yamlPath: Path of the YAML file holding the ``label_names`` mapping.

    Raises:
        FileNotFoundError: If ``yamlPath`` does not exist.
    """
    if not os.path.exists(yamlPath):
        raise FileNotFoundError('yaml file not found!')

    # NOTE(security): FullLoader is not meant for untrusted input; if the
    # label file can come from outside, switch to yaml.safe_load.
    with open(yamlPath, encoding='utf-8') as f:
        label_masks = yaml.load(f, Loader=yaml.FullLoader)['label_names']

    fileName = os.path.basename(oriImgPath)
    # splitext copes with extensions of any length (.jpeg, .tiff, ...).
    saveXmlPath = os.path.join(savePath, os.path.splitext(fileName)[0] + '.xml')

    labelImg = io.imread(labelPath) if isinstance(labelPath, str) else labelPath
    imgHeight, imgWidth = labelImg.shape[:2]

    objs = []
    for k, v in label_masks.items():
        # Use the integer form everywhere so labels written as strings in the
        # YAML file are filtered the same way as they are masked.
        label_value = int(v)

        # Box filter, adjust as needed: labels 1-18 need a box larger than
        # 20x20 with area >= 300; label 19 only needs area >= 5. Any other
        # label never yields a box, so skip it before touching the image.
        if 1 <= label_value <= 18:
            min_side, min_area = 20, 300
        elif label_value == 19:
            min_side, min_area = None, 5
        else:
            continue

        class_mask = labelImg == label_value
        if not class_mask.any():
            continue

        # Binary 0/255 image of this class only, forced to 8-bit.
        binary = class_mask.astype(np.uint8) * 255
        # Each stats row holds x, y, width, height and area of one component.
        _, _, stats, _ = cv2.connectedComponentsWithStats(binary)

        # Row 0 is skipped (the loop starts at the first labelled component).
        for st in stats[1:]:
            xmin, ymin, width, height, area = (int(x) for x in st[:5])

            if area < min_area:
                continue
            if min_side is not None and not (width > min_side and height > min_side):
                continue

            objs.append({
                'type': 'bndbox',
                'name': k,
                'pose': 'Unspecified',
                'truncated': 0,
                'difficult': 0,
                'bndbox': {'xmin': xmin, 'ymin': ymin,
                           'xmax': xmin + width, 'ymax': ymin + height},
            })

    img2xml_multiobj(saveXmlPath, saveXmlPath, "image", fileName, oriImgPath, imgWidth, imgHeight, objs)


if __name__ == "__main__":
    # Dataset root; fill in before running. It must contain ``images``,
    # ``mask`` and ``label_names.yaml``.
    path = ''
    init_path = '%s/images' % path
    mask_path = '%s/mask' % path

    yaml_file = '%s/label_names.yaml' % path
    save_xml = '%s/Annotations' % path
    # Also creates missing parent directories and tolerates an existing one.
    os.makedirs(save_xml, exist_ok=True)

    # glob returns files in arbitrary order; sort both lists so that the
    # n-th mask is paired with the n-th image (this relies on masks and
    # images sharing the same file names).
    mask_images_list = sorted(glob(os.path.join(mask_path, "*.png")))
    init_images_list = sorted(glob(os.path.join(init_path, "*.png")))

    if len(mask_images_list) != len(init_images_list):
        # zip() stops at the shorter list, so make the mismatch visible.
        print('warning: %d masks but %d images; extra files are ignored'
              % (len(mask_images_list), len(init_images_list)))

    for mask_image, init_image in zip(mask_images_list, init_images_list):
        print(init_image)
        getMultiObjs_voc_withYaml(init_image, mask_image, save_xml, yaml_file)

label_names.yaml格式

label_names:
    Tag1: 1
    Tag2: 2
    # 每行格式为 "类别名: 掩码像素值"(像素值为整数),可按需继续添加

xml格式

<annotation>
	<folder>image</folder>
	<filename>000001_left0_Affine.png</filename>
	<path>000001_left0_Affine.png</path>
	<source>
		<database>Unknown</database>
	</source>
	<size>
		<width>1500</width>
		<height>1000</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<type>bndbox</type>
		<name>Tag1</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>948</xmin>
			<ymin>848</ymin>
			<xmax>1081</xmax>
			<ymax>913</ymax>
		</bndbox>
	</object>
</annotation>

参考

https://github.com/guchengxi1994/mask2json

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值