参加比赛时，官方给出的标注文件是 JSON 格式，因此我自己编写了代码将其转换为 XML 格式。XML 的读写使用的是 Python 标准库中的 xml.etree.ElementTree（另外用 OpenCV 读取图片尺寸，用 NumPy 整理每条标注记录）。
官方给出的 JSON 格式：每个标注文件是一个列表，列表中的每条记录包含 image_name、ignore、category_name、id、occlusion、truncation 和 bbox 字段。
转换代码如下:
import copy
import json
import os

import cv2
import numpy as np

# xml.etree.cElementTree was removed in Python 3.9; fall back to ElementTree.
try:
    import xml.etree.cElementTree as ElementTree
except ImportError:
    import xml.etree.ElementTree as ElementTree
# ---- Input locations -------------------------------------------------------
template_file = '../intelligent/1.xml'   # XML template copied for every image
image_dir = './picture/'                 # images, read to fill in <size>
train_file = '../intelligent/labels/'    # folder holding the JSON label files
# Label folder given relative to the working directory.
# NOTE(review): same folder as train_file only when run from ../intelligent.
json_path = './labels'

# ---- Output location -------------------------------------------------------
target_dir = './xmll/'                   # generated XML files are written here

# Placeholder for the file name of the image being converted.
file_nameee = ''
# 提取json
def parse_json(d):
arr = np.array([
d['image_name'], d['ignore'],d['category_name'],d['id'],d['occlusion'],d['truncation'],d['bbox']
])
return arr
def chinese_english(str):
    """Translate a Chinese category name into the English label used in XML.

    Any category not in the table below maps to 'other'.  The parameter keeps
    its original name (shadowing the builtin ``str``) for call compatibility.
    """
    # (source category, XML label) pairs, compared with == in order.
    # NOTE(review): 'bcycle' looks like a typo for 'bicycle'; it is left as-is
    # because it is the label written into the generated XML files.
    label_table = (
        ('suv', 'suv'),
        ('专业作业车', 'Professional_work_car'),
        ('儿童', 'children'),
        ('面包车', 'Van_minibus'),
        ('大客车', 'bus'),
        ('大货车', 'Big_truck'),
        ('小货车', 'buggy'),
        ('成年人', 'adult'),
        ('电动/摩托三轮车', 'Electric_tricycle'),
        ('电动/摩托车', 'Electric_motorcycle'),
        ('自行车', 'bcycle'),
        ('轿车', 'car'),
    )
    for source_name, xml_label in label_table:
        if str == source_name:
            return xml_label
    return 'other'
def read_xml(in_path):
    """Parse the XML file at ``in_path`` and return its ElementTree.

    The path is printed first as a simple progress trace.
    """
    print(in_path)
    return ElementTree.parse(in_path)
def tiqu(strrr):
    """Return every run of consecutive digits in ``strrr`` as an int, in order.

    Every non-digit character acts as a separator, e.g. 'a12b3' -> [12, 3];
    a string without digits yields [].
    """
    # Blank out non-digits so str.split() leaves only the digit runs.
    spaced = ''.join(c if c.isdigit() else ' ' for c in strrr)
    return [int(token) for token in spaced.split()]
def _bbox_corners(bbox):
    """Return (xmin, ymin, xmax, ymax) of one bbox as integer strings.

    The arithmetic treats ``bbox`` as [center_x, center_y, width, height].
    NOTE(review): confirm against the official label format; a top-left
    [x, y, w, h] box would need different maths.  int() truncates toward 0.
    """
    cx, cy, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
    return (str(int(cx - w / 2)), str(int(cy - h / 2)),
            str(int(cx + w / 2)), str(int(cy + h / 2)))


def _truncated_flag(occlusion, truncation):
    """Return '0' when the first number in both fields is 0, otherwise '1'.

    Both fields are parsed with ``tiqu``.  Raises IndexError if an inspected
    field contains no digits.
    """
    occ_numbers = tiqu(occlusion)
    trunc_numbers = tiqu(truncation)
    return '0' if occ_numbers[0] == 0 and trunc_numbers[0] == 0 else '1'


def _fill_object(obj, record):
    """Write one record's label, truncated flag and box into ``obj``.

    ``record`` is the array from ``parse_json``: index 2 is the category,
    4/5 are occlusion/truncation, and 6 is the bbox.
    """
    obj.find('name').text = chinese_english(record[2])
    obj.find('truncated').text = _truncated_flag(record[4], record[5])
    bndbox = obj.find('bndbox')
    corners = _bbox_corners(record[6])
    for tag, value in zip(('xmin', 'ymin', 'xmax', 'ymax'), corners):
        bndbox.find(tag).text = value


def _build_tree(file_name, records):
    """Return an XML tree for one image with one <object> per record.

    The template is re-read for every image so no state leaks between images.

    Raises:
        FileNotFoundError: if the image cannot be read; its size is needed
            for the <size> element.
    """
    tree = read_xml(template_file)
    root = tree.getroot()
    root.find('filename').text = file_name

    image_path = os.path.join(image_dir, file_name)
    im = cv2.imread(image_path)
    if im is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError('cannot read image: ' + image_path)
    size = root.find('size')
    size.find('height').text = str(im.shape[0])
    size.find('width').text = str(im.shape[1])
    size.find('depth').text = str(im.shape[2])

    # The template's <object> holds the first box.  Every further box gets a
    # deep copy of an untouched template object, so the copies do not share
    # children.
    first_obj = root.find('object')
    blank_obj = copy.deepcopy(first_obj)
    _fill_object(first_obj, records[0])
    for record in records[1:]:
        obj = copy.deepcopy(blank_obj)
        _fill_object(obj, record)
        root.append(obj)
    return tree


# Convert every JSON label file into one XML file per image, based on the template.
os.makedirs(target_dir, exist_ok=True)
for json_name in os.listdir(train_file):
    # List and open the same folder, so every listed name can be opened.
    with open(os.path.join(train_file, json_name), encoding='UTF-8') as fh:
        annotations = json.load(fh)

    # Group records by image, keeping first-seen order.  Each image then gets
    # its own XML, and an empty file writes nothing instead of re-writing the
    # previous image's tree.
    records_by_image = {}
    for item in annotations:
        record = parse_json(item)
        print(record[0])
        records_by_image.setdefault(record[0], []).append(record)

    for file_name, records in records_by_image.items():
        tree = _build_tree(file_name, records)
        # Swap only the extension; str.replace('jpg', ...) would also alter
        # "jpg" elsewhere in the name and ignore other image types.
        xml_file = os.path.splitext(file_name)[0] + '.xml'
        print("deakpf" + xml_file)
        tree.write(os.path.join(target_dir, xml_file), encoding='utf-8')