"""
--function 1.通过在xml_base中的xml将img_base中每个类图像复制到shutil_copy_to
2.在txt_base生成train.txt,val.txt文件
--author 图像组——***
--time 2024.5.9
--note 完成Voc标签标注后,将annotation从图像文件夹中拿出来,再根据annotation将每个打过标签的图像拿出来,并按照10%作为验证集划分训练集和验证集
"""
import os
import shutil
from lxml import etree
def parse_xml_to_dict(xml):
    """
    Recursively convert an XML element into a nested dictionary.

    Modeled on TensorFlow's ``recursive_parse_xml_to_dict``.

    Args:
        xml: an element tree node (e.g. obtained from ``lxml.etree.fromstring``).

    Returns:
        ``{tag: value}`` where ``value`` is the element's text for a leaf
        element, otherwise a dict built from its children. Children tagged
        ``object`` are gathered into a list (an annotation may contain several);
        for any other repeated child tag the last occurrence wins.
    """
    tag = xml.tag
    if not len(xml):
        # Leaf element: map the tag straight to its text content.
        return {tag: xml.text}

    mapping = {}
    for node in xml:
        value = parse_xml_to_dict(node)[node.tag]
        if node.tag == 'object':
            # Several <object> entries may appear, so accumulate them in a list.
            mapping.setdefault(node.tag, []).append(value)
        else:
            mapping[node.tag] = value
    return {tag: mapping}
def findAllxml(base):
    """
    Walk ``base`` recursively and yield one pair per VOC XML annotation file.

    Each ``.xml`` file (extension matched case-insensitively) is parsed with
    lxml, converted via ``parse_xml_to_dict``, and its ``<filename>`` entry is
    reported together with the XML file's own name minus its extension.

    Args:
        base: root directory containing the annotation files.

    Yields:
        ``(img_name, xml_stem)``: ``img_name`` is the text of the annotation's
        ``<filename>`` element; ``xml_stem`` is the XML file name with its final
        extension removed (dots inside the name are kept).
    """
    for root, _dirs, files in os.walk(base):
        for f in files:
            stem, ext = os.path.splitext(f)
            # Skip non-annotation files (e.g. Thumbs.db, desktop.ini); feeding
            # them to the XML parser would abort the whole run.
            if ext.lower() != ".xml":
                continue
            # Read raw bytes: lxml rejects str input that carries an
            # <?xml ... encoding=...?> declaration, and bytes let the parser use
            # the document's declared encoding instead of the platform default.
            with open(os.path.join(root, f), "rb") as fid:
                xml_bytes = fid.read()
            xml = etree.fromstring(xml_bytes)
            data = parse_xml_to_dict(xml)["annotation"]
            # splitext keeps inner dots ("a.b.xml" -> "a.b"); split(".")[0] did not.
            yield data["filename"], stem
def findAllFile(base):
    """
    Recursively list every file beneath ``base``.

    Args:
        base: root directory to walk.

    Yields:
        ``(full_path, file_name)`` for each file, in ``os.walk`` order.
    """
    for dirpath, _subdirs, names in os.walk(base):
        yield from ((os.path.join(dirpath, name), name) for name in names)
def main(xml_base=r'E:\ProjectImage\Dataset\Annotations',
         img_base=r'E:\ProjectImage\Dataset\c_medcien',
         shutil_copy_to=r"E:\ProjectImage\Dataset\cm55\JPEGImages",
         txt_base=r"E:\ProjectImage\Dataset\cm55\ImageSets\Main",
         val_every=10):
    """
    Copy every annotated image into a VOC ``JPEGImages`` folder and write the split lists.

    For each annotation yielded by ``findAllxml(xml_base)``, the image named in
    it is looked up among the files under ``img_base`` and copied into
    ``shutil_copy_to``. The annotation stems are then split: indices
    0, val_every, 2*val_every, ... go to ``val.txt`` and the rest to
    ``train.txt``, both written into ``txt_base`` (one name per line).

    Args:
        xml_base: directory tree holding the VOC XML annotations.
        img_base: directory tree holding the source images.
        shutil_copy_to: destination directory for copied images (created if missing).
        txt_base: destination directory for train.txt / val.txt (created if missing).
        val_every: hold out one annotation in every ``val_every`` for validation;
            the default 10 gives a train:val ratio of 9:1.

    Raises:
        ValueError: if ``val_every`` < 1, or if an annotation references an
            image that is not found under ``img_base``.
    """
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every}")

    # Create outputs up front: shutil.copy into a non-existent directory treats
    # the path as a *file* name and would overwrite it with every image.
    os.makedirs(shutil_copy_to, exist_ok=True)
    os.makedirs(txt_base, exist_ok=True)

    # Index file name -> full path once, giving O(1) lookups instead of a
    # list.index scan per annotation. setdefault keeps the first path seen,
    # matching list.index when one name exists in several sub-folders.
    path_by_name = {}
    for path, filename in findAllFile(img_base):
        path_by_name.setdefault(filename, path)

    xml_list = []
    for img_name, xml_name in findAllxml(xml_base):
        src = path_by_name.get(img_name)
        if src is None:
            raise ValueError(
                f"image {img_name!r} referenced by annotation {xml_name!r} "
                f"was not found under {img_base}"
            )
        shutil.copy(src, shutil_copy_to)
        xml_list.append(xml_name)

    # Every val_every-th annotation (index 0, 10, 20, ... by default) goes to
    # validation, i.e. train:val = 9:1 with the default of 10.
    eval_list = xml_list[::val_every]
    train_list = [name for idx, name in enumerate(xml_list) if idx % val_every]

    with open(os.path.join(txt_base, "train.txt"), "w") as f_train:
        f_train.writelines(line + '\n' for line in train_list)
    with open(os.path.join(txt_base, "val.txt"), "w") as f_eval:
        f_eval.writelines(line + '\n' for line in eval_list)
if __name__ == "__main__":
    # Run the copy-and-split job only when executed as a script, not on import.
    main()
# 自用:完成Voc标签标注后提取相关图像,并划分训练集和验证集
# 于 2024-05-09 23:01:09 首次发布