# -*- coding:utf-8 -*-
import os
import json
import numpy as np
#from xml.etree import ElementTree as etree
from xml.etree.ElementTree import Element
from xml.etree.ElementTree import SubElement
from xml.etree.ElementTree import ElementTree
# Hard-coded locations for the labelme JSON -> Pascal VOC XML conversion.
imagePath = r'E:\Desktop\SteelCoilsDetection\test\images'
jsonPath = r'E:\Desktop\SteelCoilsDetection\test\json'
savePath = r'E:\Desktop\SteelCoilsDetection\test\xml'


def shapeToBndbox(points):
    """Return the bounding box ``(xmin, ymin, xmax, ymax)`` of a point list.

    Args:
        points: non-empty sequence of ``[x, y]`` pairs, e.g. the ``"points"``
            entry of a labelme shape (polygon or rectangle).

    Returns:
        A 4-tuple of ints; each extreme coordinate is truncated with ``int()``.
    """
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    return int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))


def buildAnnotation(jsonDic, stem, imageDir=imagePath, skipLabels=('a',)):
    """Build a Pascal VOC ``<annotation>`` element from a parsed labelme JSON dict.

    Args:
        jsonDic: parsed JSON; reads ``imageWidth``, ``imageHeight`` and
            ``shapes`` (each shape needs ``label`` and ``points``).
        stem: file name without extension; written to ``<filename>`` and
            joined with ``imageDir`` for ``<path>``.
        imageDir: directory used to build ``<path>``.
        skipLabels: labels whose shapes are not exported as objects.

    Returns:
        The root ``Element`` of the annotation tree.

    Raises:
        KeyError: if a required key is missing from ``jsonDic`` or a shape.
    """
    annotation = Element('annotation')
    SubElement(annotation, 'folder').text = "images"
    SubElement(annotation, 'filename').text = stem
    # os.path.join inserts the separator that plain string concatenation omits.
    # NOTE(review): no image extension is appended, so this is not the exact
    # image file path; jsonDic['imagePath'] may hold the real name — confirm.
    SubElement(annotation, 'path').text = os.path.join(imageDir, stem)
    source = SubElement(annotation, 'source')
    SubElement(source, 'database').text = "Unknown"
    size = SubElement(annotation, 'size')
    SubElement(size, 'width').text = str(jsonDic['imageWidth'])
    SubElement(size, 'height').text = str(jsonDic['imageHeight'])
    # Assumes 3-channel images — TODO confirm for this dataset.
    SubElement(size, 'depth').text = "3"
    SubElement(annotation, 'segmented').text = "0"
    for shape in jsonDic['shapes']:
        if shape["label"] in skipLabels:
            continue
        points = shape['points']
        # A shape without points has no bounding box; skip it instead of crashing.
        if not points:
            continue
        obj = SubElement(annotation, 'object')
        SubElement(obj, 'name').text = shape["label"]
        SubElement(obj, 'pose').text = 'Unspecified'
        SubElement(obj, 'truncated').text = str(0)
        SubElement(obj, 'difficult').text = str(0)
        bndbox = SubElement(obj, 'bndbox')
        for tag, value in zip(('xmin', 'ymin', 'xmax', 'ymax'), shapeToBndbox(points)):
            SubElement(bndbox, tag).text = str(value)
    return annotation


def convertJsonDir(srcDir=jsonPath, dstDir=savePath, imageDir=imagePath):
    """Convert every ``*.json`` file in ``srcDir`` to ``<stem>.xml`` in ``dstDir``.

    Non-JSON entries are ignored and ``dstDir`` is created if missing.
    Prints each processed file name.
    """
    # Import the ElementTree *class* locally: the pretty-printing section of
    # this file does `from xml.etree import ElementTree`, which rebinds the
    # global name to the module.
    from xml.etree.ElementTree import ElementTree as XmlTree

    os.makedirs(dstDir, exist_ok=True)
    for jsonName in sorted(os.listdir(srcDir)):
        # splitext keeps dotted stems intact ("a.b.json" -> "a.b").
        stem, ext = os.path.splitext(jsonName)
        if ext.lower() != '.json':
            continue
        print(jsonName)
        # JSON is UTF-8; without an explicit encoding, open() uses the locale
        # code page (e.g. GBK on Chinese Windows).
        with open(os.path.join(srcDir, jsonName), 'r', encoding='utf-8') as fileLoader:
            jsonDic = json.load(fileLoader)
        annotation = buildAnnotation(jsonDic, stem, imageDir)
        XmlTree(annotation).write(os.path.join(dstDir, stem + '.xml'), encoding='utf-8')


if __name__ == "__main__":
    convertJsonDir()
美化(对上一步生成的 XML 文件进行格式化,添加缩进与换行):
# -*- coding:utf-8 -*-
import os
from xml.etree import ElementTree # 导入ElementTree模块
# elemnt为传进来的Elment类,参数indent用于缩进,newline用于换行
def prettyXml(element, indent, newline, level = 0):
# 判断element是否有子元素
if element:
# 如果element的text没有内容
if element.text == None or element.text.isspace():
element.text = newline + indent * (level + 1)
else:
element.text = newline + indent * (level + 1) + element.text.strip() + newline + indent * (level + 1)
# 此处两行如果把注释去掉,Element的text也会另起一行
#else:
#element.text = newline + indent * (level + 1) + element.text.strip() + newline + indent * level
temp = list(element) # 将elemnt转成list
for subelement in temp:
# 如果不是list的最后一个元素,说明下一个行是同级别元素的起始,缩进应一致
if temp.index(subelement) < (len(temp) - 1):
subelement.tail = newline + indent * (level + 1)
else: # 如果是list的最后一个元素, 说明下一行是母元素的结束,缩进应该少一个
subelement.tail = newline + indent * level
# 对子元素进行递归操作
prettyXml(subelement, indent, newline, level = level + 1)
# Directory whose XML files are re-indented in place.
# (Named xmlDir to avoid shadowing the builtin dir().)
xmlDir = r'E:\Desktop\SteelCoilsDetection\test\xml'

if __name__ == "__main__":
    for fileName in sorted(os.listdir(xmlDir)):
        # Skip non-XML entries, which ElementTree.parse would fail on.
        if not fileName.lower().endswith('.xml'):
            continue
        print(fileName)
        xmlFile = os.path.join(xmlDir, fileName)
        # Here ElementTree is the xml.etree.ElementTree module.
        tree = ElementTree.parse(xmlFile)
        # Indent with tabs and newlines, then overwrite the original file.
        prettyXml(tree.getroot(), '\t', '\n')
        tree.write(xmlFile, encoding='utf-8')
借鉴:https://zhuanlan.zhihu.com/p/54269963