代码
import sys
import os
import cv2
import yaml
import copy
import xmltodict
import numpy as np
import skimage.io as io
import xml.etree.ElementTree as ET
from xml.dom.minidom import parse
from glob import glob
sys.path.append('..')
def json_to_xml(json_str):
    """Serialize a dict-like document to a pretty-printed XML string.

    Despite the parameter name, the value is handed straight to
    ``xmltodict.unparse``; the caller in this file passes a nested ``dict``
    with a single root key.

    Args:
        json_str: Mapping to serialize.

    Returns:
        The XML text produced by ``xmltodict.unparse`` with ``pretty=1``.
    """
    # ``pretty`` asks xmltodict to emit formatted (multi-line) XML.
    return xmltodict.unparse(json_str, pretty=1)
def img2xml(folder: str, filename: str, path: str, width: int, height: int, type: str, name: str,
            pose: str, truncated: int, difficult: int, xmin: int, ymin: int, xmax: int, ymax: int):
    """Build an annotation XML string that describes a single object.

    The document has one ``annotation`` root holding the image metadata
    (``folder``, ``filename``, ``path``, ``source``, ``size``, ``segmented``)
    followed by one ``object`` entry with its bounding box.

    Args:
        folder: Value written to ``<folder>``.
        filename: Value written to ``<filename>`` and to ``<path>``.
        path: Accepted for interface compatibility; not used (see note below).
        width: Image width written to ``<size>``.
        height: Image height written to ``<size>``.
        type: Value written to the object's ``<type>``.
        name: Value written to the object's ``<name>``.
        pose: Value written to the object's ``<pose>``.
        truncated: Value written to the object's ``<truncated>``.
        difficult: Value written to the object's ``<difficult>``.
        xmin: Bounding-box coordinate written to ``<bndbox>``.
        ymin: Bounding-box coordinate written to ``<bndbox>``.
        xmax: Bounding-box coordinate written to ``<bndbox>``.
        ymax: Bounding-box coordinate written to ``<bndbox>``.

    Returns:
        The XML text returned by ``json_to_xml``.
    """
    detected_object = {
        'type': type,
        'name': name,
        'pose': pose,
        'truncated': truncated,
        'difficult': difficult,
        'bndbox': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax},
    }
    document = {
        'annotation': {
            'folder': folder,
            'filename': filename,
            # NOTE(review): <path> is filled with ``filename``; the ``path``
            # argument is ignored. Confirm this is intended before changing it.
            'path': filename,
            'source': {'database': "Unknown"},
            # Depth is always written as 3.
            'size': {'width': width, 'height': height, 'depth': 3},
            'segmented': 0,
            'object': detected_object,
        }
    }
    return json_to_xml(document)
def writeXML(domTree_path, aimPath, type: str, name: str, pose: str, bndbox: dict,
             truncated: int = 0, difficult: int = 0):
    """Append one ``<object>`` element to an existing annotation XML file.

    The document at ``domTree_path`` is parsed with ``xml.dom.minidom``, a new
    ``<object>`` (with ``type``, ``name``, ``pose``, ``truncated``,
    ``difficult`` and ``bndbox`` children) is appended to its root element and
    the whole document is written to ``aimPath``.

    If ``domTree_path`` does not exist the function does nothing and
    ``aimPath`` is not touched.

    Args:
        domTree_path: Path of the XML file to extend.
        aimPath: Path the extended document is written to (may equal
            ``domTree_path``).
        type: Text for the object's ``<type>`` element.
        name: Text for the object's ``<name>`` element.
        pose: Text for the object's ``<pose>`` element.
        bndbox: Mapping with the keys ``xmin``, ``ymin``, ``xmax``, ``ymax``.
        truncated: Text for ``<truncated>``; defaults to 0.
        difficult: Text for ``<difficult>``; defaults to 0.

    Raises:
        KeyError: If ``bndbox`` lacks one of the four coordinate keys.
    """
    if not os.path.exists(domTree_path):
        # Best-effort behaviour: there is no document to extend.
        return

    dom_tree = parse(domTree_path)

    def _text_element(tag, value):
        """Create ``<tag>value</tag>`` owned by ``dom_tree``."""
        node = dom_tree.createElement(tag)
        node.appendChild(dom_tree.createTextNode(str(value)))
        return node

    object_node = dom_tree.createElement("object")
    scalar_fields = (
        ("type", type),
        ("name", name),
        ("pose", pose),
        ("truncated", truncated),
        ("difficult", difficult),
    )
    for tag, value in scalar_fields:
        object_node.appendChild(_text_element(tag, value))

    bndbox_node = dom_tree.createElement("bndbox")
    for tag in ("xmin", "ymin", "xmax", "ymax"):
        bndbox_node.appendChild(_text_element(tag, bndbox[tag]))
    object_node.appendChild(bndbox_node)

    dom_tree.documentElement.appendChild(object_node)

    # The XML declaration announces utf-8, so the file itself must be encoded
    # as utf-8 regardless of the platform's default text encoding.
    with open(aimPath, 'w', encoding='utf-8') as f:
        dom_tree.writexml(f, addindent='', encoding='utf-8')
def prettyXml(element,
              indent,
              newline,
              level=0):
    """Indent an ``ElementTree`` element and its descendants in place.

    Whitespace is placed in ``text`` (for elements that have children) and in
    ``tail`` (for every child) so that each child starts on its own line.
    Elements without children keep their ``text`` untouched.

    Args:
        element: The ``xml.etree.ElementTree.Element`` to format.
        indent: String used for one indentation step (e.g. ``'\\t'``).
        newline: String used as line break (e.g. ``'\\n'``).
        level: Nesting depth of ``element``; callers normally leave it at 0.

    Returns:
        The same ``element`` object, after modification.
    """
    children = list(element)

    # Check the child count explicitly: truth-testing an Element is
    # deprecated, and leaf text must never be rewritten.
    if children:
        child_indent = newline + indent * (level + 1)
        if element.text is None or element.text.isspace():
            element.text = child_indent
        else:
            # Mixed content: keep the stripped text on its own indented line.
            element.text = child_indent + element.text.strip() + child_indent

    last_position = len(children) - 1
    for position, child in enumerate(children):
        if position < last_position:
            # The next line starts a sibling, so keep the child indentation.
            child.tail = newline + indent * (level + 1)
        else:
            # The next line closes the parent, so step back one level.
            child.tail = newline + indent * level
        prettyXml(child, indent, newline, level=level + 1)
    return element
def img2xml_multiobj(tmpPath: str, aimPath: str, folder: str, filename: str,
                     path: str, width: int, height: int, objs: list):
    """Write an annotation XML file containing every object in ``objs``.

    The first object is serialized together with the image metadata through
    ``img2xml``; each further object is appended with ``writeXML``. The
    finished document is re-indented with ``prettyXml`` (tab indentation) and
    written to ``tmpPath`` and, when it is a different path, to ``aimPath``.

    Args:
        tmpPath: Working file the document is assembled in.
        aimPath: Destination file; the only caller in this file passes the
            same path as ``tmpPath``.
        folder: Forwarded to ``img2xml``.
        filename: Forwarded to ``img2xml``.
        path: Forwarded to ``img2xml``.
        width: Image width, forwarded to ``img2xml``.
        height: Image height, forwarded to ``img2xml``.
        objs: List of dicts with the keys ``type``, ``name``, ``pose``,
            ``truncated``, ``difficult`` and ``bndbox`` (itself a dict with
            ``xmin``, ``ymin``, ``xmax``, ``ymax``).

    Returns:
        None. When ``objs`` is empty nothing is written.
    """
    if not objs:
        # Without at least one object there is no document to build, and
        # parsing a file that was never written would fail.
        return

    first, *remaining = objs
    first_box = first['bndbox']
    xml_text = img2xml(folder, filename, path, width, height,
                       first['type'], first['name'], first['pose'],
                       first['truncated'], first['difficult'],
                       first_box['xmin'], first_box['ymin'],
                       first_box['xmax'], first_box['ymax'])
    with open(tmpPath, 'w', encoding='utf-8') as f:
        f.write(xml_text)

    for extra in remaining:
        box = extra['bndbox']
        bndbox = {'xmin': box['xmin'], 'ymin': box['ymin'],
                  'xmax': box['xmax'], 'ymax': box['ymax']}
        # Accumulate in the working file so every object is kept even when
        # aimPath differs from tmpPath.
        # NOTE: truncated/difficult are not forwarded here, so writeXML's
        # value (0) is used for every object after the first.
        writeXML(tmpPath, tmpPath, extra['type'], extra['name'], extra['pose'], bndbox)

    root = prettyXml(ET.parse(tmpPath).getroot(), '\t', '\n')
    tree = ET.ElementTree(root)
    tree.write(tmpPath)
    if aimPath != tmpPath:
        tree.write(aimPath)
def _keep_component(label_value, width, height, area):
    """Return True when a connected component passes the per-class size filter."""
    # These thresholds are dataset-specific; adjust them as needed.
    if 1 <= label_value <= 18:
        return width > 20 and height > 20 and area >= 300
    if label_value == 19:
        return area >= 5
    return False


def getMultiObjs_voc_withYaml(oriImgPath, labelPath, savePath, yamlPath=''):
    """Convert a label mask into an XML annotation with one box per region.

    For every ``class name -> pixel value`` entry under ``label_names`` in the
    YAML file, the pixels of the mask equal to that value are split into
    connected components, and the bounding rectangle of each component that
    passes ``_keep_component`` becomes one object. The result is written to
    ``<savePath>/<image name without extension>.xml`` via
    ``img2xml_multiobj``.

    Args:
        oriImgPath: Path of the original image; only its file name and the
            path string itself are used.
        labelPath: Path of the mask image, or an already loaded mask array.
        savePath: Directory the XML file is written to.
        yamlPath: Path of the YAML file that holds the ``label_names`` mapping.

    Raises:
        FileNotFoundError: If ``yamlPath`` does not exist.
    """
    if not os.path.exists(yamlPath):
        raise FileNotFoundError('yaml file not found!')
    with open(yamlPath, encoding='utf-8') as f:
        # NOTE(security): yaml.load can build more than plain data; if the
        # label file may come from an untrusted source, use yaml.safe_load.
        label_masks = yaml.load(f, Loader=yaml.FullLoader)['label_names']

    # basename/splitext cope with either path separator and any extension
    # length (the previous [:-4] slice assumed a 3-character extension).
    fileName = os.path.basename(oriImgPath)
    saveXmlPath = savePath + os.sep + os.path.splitext(fileName)[0] + '.xml'

    labelImg = io.imread(labelPath) if isinstance(labelPath, str) else labelPath
    imgHeight, imgWidth = labelImg.shape[:2]

    objs = []
    for class_name, raw_value in label_masks.items():
        label_value = int(raw_value)
        if label_value <= 0:
            # Non-positive values can never yield a foreground region.
            continue

        region = labelImg == label_value
        if not region.any():
            continue

        # Consider one label at a time: that label becomes 255, the rest 0.
        # The cast gives cv2 an 8-bit image whatever dtype the mask has.
        binary = region.astype(np.uint8) * 255
        _, _, stats, _ = cv2.connectedComponentsWithStats(binary)

        # Each stats row is (x, y, width, height, area); row 0 is skipped,
        # as in the original loop.
        for row in stats[1:]:
            xmin, ymin, width, height, area = (int(value) for value in row)
            if not _keep_component(label_value, width, height, area):
                continue
            objs.append({
                'type': 'bndbox',
                'name': class_name,
                'pose': 'Unspecified',
                'truncated': 0,
                'difficult': 0,
                'bndbox': {'xmin': xmin, 'ymin': ymin,
                           'xmax': xmin + width, 'ymax': ymin + height},
            })

    img2xml_multiobj(saveXmlPath, saveXmlPath, "image", fileName, oriImgPath,
                     imgWidth, imgHeight, objs)
if __name__ == "__main__":
    # Dataset root; it is expected to contain images/, mask/ and
    # label_names.yaml. Annotations/ is created next to them.
    path = ''
    init_path = '%s/images' % path
    mask_path = '%s/mask' % path
    yaml_file = '%s/label_names.yaml' % path
    save_xml = '%s/Annotations' % path

    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(save_xml, exist_ok=True)

    # glob() gives no ordering guarantee; sort both lists so that zip() pairs
    # each mask with the image of the same name.
    mask_images_list = sorted(glob(os.path.join(mask_path, "*.png")))
    init_images_list = sorted(glob(os.path.join(init_path, "*.png")))

    for mask_image, init_image in zip(mask_images_list, init_images_list):
        print(init_image)
        getMultiObjs_voc_withYaml(init_image, mask_image, save_xml, yaml_file)
label_names.yaml格式
label_names:
  Tag1: 1
  Tag2: 2
  # 每行格式为 类别: 掩码像素值(整数)
  # ....
xml格式
<annotation>
<folder>image</folder>
<filename>000001_left0_Affine.png</filename>
<path>000001_left0_Affine.png</path>
<source>
<database>Unknown</database>
</source>
<size>
<width>1500</width>
<height>1000</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<type>bndbox</type>
<name>Tag1</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>948</xmin>
<ymin>848</ymin>
<xmax>1081</xmax>
<ymax>913</ymax>
</bndbox>
</object>
</annotation>
参考
https://github.com/guchengxi1994/mask2json