目标检测中如何读取.xml后缀格式的标签文件
本文代码是参考B站视频(BV1Q54y1D7vj)讲解yolo v4网络的代码(voc_annotation.py和dataloader两个文件)
文件介绍
在我的项目中有两个文件夹和一个txt文件,一个jpg文件夹保存的是jpg格式图片,xml文件夹里保存的是xml格式的标签文件,两个文件名是对应的,classes.txt保存的是class类型名称。
-
010002.xml标签文件部分内容结构如下
<?xml version="1.0" ?><annotation> <folder>JPEGImages</folder> <filename>010002.jpg</filename> <path></path> <source> <database>Unknown</database> </source> <size> <width>1280</width> <height>1024</height> <depth>1</depth> </size> <segmented>0</segmented> <object> <name>person</name> <pose>Unspecified</pose> <truncated>1</truncated> <difficult>0</difficult> <bndbox> <xmin>978</xmin> <ymin>1</ymin> <xmax>1074</xmax> <ymax>248</ymax> </bndbox>
我们是要提取里面object里面的内容
-
classes.txt内容如下
person bicycle car motorbike aeroplane bus train truck boat traffic light fire hydrant stop sign parking meter bench bird
读取xml文件内容并保存到txt文件中
- 代码
import os import xml.etree.ElementTree as ET # 读取xml文件存放到txt文件中 in_file = open("xml/010002.xml") #构建树状结构 tree = ET.parse(in_file) #获取根节点 root = tree.getroot() # 获得类 def get_classes(classes_path): with open(classes_path, encoding='utf-8') as f: class_names = f.readlines() class_names = [c.strip() for c in class_names] return class_names, len(class_names) classes, _ = get_classes("classes.txt") jpg_dir = 'jpg' image_id = '010002' list_file = open("010002.txt", "w") #将图片路径写入txt文件中 list_file.write(os.path.join(os.path.abspath(jpg_dir), f"{image_id}.jpg")) #循环遍历 XML 文件中所有的对象(object)标签 for obj in root.iter("object"): difficult = obj.find('difficult') #标签名 cls = obj.find("name").text #标签在classes.txt中位置编号 cls_id = classes.index(cls) #如果类别不在指定的类别列表中或者 difficult 标记为 1(即困难标记),则跳过当前对象的处理 if difficult == 1: continue xmlbox = obj.find("bndbox") #边界框两个坐标点 b = [int(float(xmlbox.find("xmin").text)), int(float(xmlbox.find("ymin").text)), int(float(xmlbox.find("xmax").text)), int(float(xmlbox.find("ymax").text)), ] #将数据写入txt文件中 list_file.write(" " + ",".join([str(a) for a in b]) + "," + str(cls_id)) #关闭文件 list_file.close()
- 获得数据的txt文件内容
可以看见这张图像有四个边界框,类别都是personC:\Users\34920\Desktop\读取xml\jpg\010002.jpg 978,1,1074,248,0 402,284,518,546,0 337,353,455,604,0 292,426,415,701,0
读取txt文件,将边界框坐标和类别转换成numpy数组
-
代码
from PIL import Image import numpy as np # 将图像转换成RGB图像,防止灰度图在预测时报错。 # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB def cvtColor(image): if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: return image else: image = image.convert('RGB') return image #读取txt文件 with open("010002.txt") as f: train_lines = f.readlines() for train_line in train_lines: line=train_line.split() # print(line,type(line)) #读取图像并转换成RGB图像 image = Image.open(line[0]) # image = cvtColor(image) #line[1:]是一个字符串列表,其中每个字符串代表一组以逗号分隔的整数。 # map(int, box.split(','))将每个字符串box按逗号拆分为整数,并将其转换为整数列表。 # np.array(...)将每个整数列表转换为NumPy数组。 # 外层的np.array(...)将所有这些NumPy数组组合成一个多维数组。 box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]]) print(box.shape) print(box)
-
输出结果
从结果可以看见已经转成numpy数组格式了,最后一列就是标签类型的id号。