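1. Background of the pitfall
The task is simple: read image annotation information from an XML file (shown in the figure below) and extract the annotations from it. The number of object nodes is not fixed.
2. The buggy code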
from lxml import etree

def get_text_from_xml(tree, tag):
    return tree.xpath('//%s'%tag)[-1].text

def load_xml(infile):
    tree = etree.parse(infile)
    objs = tree.xpath('//object')
    annos = []
    for obj in objs:
        name = get_text_from_xml(obj, 'name')
        xmin = int(get_text_from_xml(obj, 'xmin'))
        ymin = int(get_text_from_xml(obj, 'ymin'))
        xmax = int(get_text_from_xml(obj, 'xmax'))
        ymax = int(get_text_from_xml(obj, 'ymax'))
        w = xmax - xmin
        h = ymax - ymin
        annos.append(dict(
            category=name,
            bbox=[xmin, ymin, w, h]))
    return annos
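My idea was simple: use xpath to find every node whose tag is object, then loop over them and convert each one into the format I wanted. But when the results came out, every parsed object was identical, and all of them carried the information of the last object in the XML (three question marks instantly flashed above my head). Just look at the screenshot, it is more intuitive.
Clearly the result is badly wrong. So where is the problem? I had taken too much for granted!
Debugging showed that while looping over objs, the xpath call inside get_text_from_xml returned the matching nodes from every object in the entire XML, not just those under the current obj. Taking the last one each time therefore naturally gives identical results. (So far I have only observed this behavior; I have not yet tracked down the cause.)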
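A likely explanation for the behavior above: in XPath, an expression that starts with // is an absolute path, so it is evaluated from the document root even when it is called on a child element. Prefixing it with a dot (.//) makes it relative to the current node. The sketch below reproduces this on a small made-up sample; SAMPLE and get_text_relative are hypothetical names that do not appear in the original code, and the XML layout is only assumed to resemble the real annotation file.

```python
from lxml import etree

# Hypothetical sample, assumed to resemble the annotation file in this post.
SAMPLE = b"""<annotation>
  <object>
    <name>cat</name>
    <bndbox><xmin>1</xmin><ymin>2</ymin><xmax>11</xmax><ymax>22</ymax></bndbox>
  </object>
  <object>
    <name>dog</name>
    <bndbox><xmin>5</xmin><ymin>6</ymin><xmax>30</xmax><ymax>40</ymax></bndbox>
  </object>
</annotation>"""

root = etree.fromstring(SAMPLE)
first_obj = root.xpath('//object')[0]

# Absolute path: searched from the document root, ignoring first_obj.
absolute = [e.text for e in first_obj.xpath('//name')]
# Relative path: searched only among the descendants of first_obj.
relative = [e.text for e in first_obj.xpath('.//name')]


def get_text_relative(node, tag):
    """Hypothetical variant of get_text_from_xml that stays inside `node`."""
    return node.xpath('.//%s' % tag)[0].text


# With the relative lookup, each object yields its own values.
fixed = [
    (get_text_relative(obj, 'name'), int(get_text_relative(obj, 'xmin')))
    for obj in root.xpath('//object')
]
```

If that explanation holds for the real file, changing '//%s' to './/%s' in get_text_from_xml should be enough to make the original loop work. Taking index [0] assumes the object's own name comes before any nested name in document order.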
3. Solutions
I came up with two solutions.
Method 1:
Use lxml's getchildren function and parse the tree level by level. The code is as follows:
def load_xml(infile):
    tree = etree.parse(infile)  # was missing: tree must be parsed before use
    root = tree.getroot()
    annos = []
    for obj in root.getchildren():
        # only the object nodes carry annotations
        if obj.tag != 'object':
            continue
        anno = {}
        for attr in obj.getchildren():
            if attr.tag == 'name':
                anno['category'] = attr.text
            elif attr.tag == 'bndbox':
                for coor in attr.getchildren():
                    if coor.tag == 'xmin':
                        xmin = int(coor.text)
                    elif coor.tag == 'xmax':
                        xmax = int(coor.text)
                    elif coor.tag == 'ymin':
                        ymin = int(coor.text)
                    elif coor.tag == 'ymax':
                        ymax = int(coor.text)
                # convert corner coordinates to [x, y, width, height]
                w = xmax - xmin
                h = ymax - ymin
                anno['bbox'] = [xmin, ymin, w, h]
        annos.append(anno)
    return annos
Crude but effective. However, the deep nesting makes the code ugly.
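For comparison, here is a flatter sketch of the same level-by-level idea. It swaps getchildren for findall, find and findtext, which lxml elements also provide and which look up direct children by tag. This is not the code from the post: load_xml_flat and SAMPLE are hypothetical names, the function takes an already parsed root element to keep the example self-contained, and it assumes every object has a name and a complete bndbox.

```python
from lxml import etree

# Hypothetical sample, assumed to resemble the annotation file in this post.
SAMPLE = b"""<annotation>
  <object>
    <name>cat</name>
    <bndbox><xmin>1</xmin><ymin>2</ymin><xmax>11</xmax><ymax>22</ymax></bndbox>
  </object>
  <object>
    <name>dog</name>
    <bndbox><xmin>5</xmin><ymin>6</ymin><xmax>30</xmax><ymax>40</ymax></bndbox>
  </object>
</annotation>"""


def load_xml_flat(root):
    """Hypothetical helper: collect annotations from a parsed root element."""
    annos = []
    # findall('object') returns only the direct children tagged 'object'.
    for obj in root.findall('object'):
        box = obj.find('bndbox')
        # findtext returns the text of the first direct child with that tag.
        xmin, ymin, xmax, ymax = (
            int(box.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        annos.append(dict(
            category=obj.findtext('name'),
            bbox=[xmin, ymin, xmax - xmin, ymax - ymin]))
    return annos


annos = load_xml_flat(etree.fromstring(SAMPLE))
```

To use it on a file, something like load_xml_flat(etree.parse(infile).getroot()) should work.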
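Method 2:
Still use the xpath function, but pull out every field that needs parsing from the whole XML at once (this works because name/xmin/xmax/ymin/ymax always appear together). Straight to the code: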
def load_xml(infile):
    tree = etree.parse(infile)
    names = tree.xpath('//name')
    names = [i for i in names if i.getparent().tag=='object']
    xmins = tree.xpath('//xmin')
    xmaxs = tree.xpath('//xmax')
    ymins = tree.xpath('//ymin')
    ymaxs = tree.xpath('//ymax')
    xmins = [int(i.text) for i in xmins]
    xmaxs = [int(i.text) for i in xmaxs]
    ymins = [int(i.text) for i in ymins]
    ymaxs = [int(i.text) for i in ymaxs]
    ws = [xmaxs[i] - xmins[i] for i in range(len(names))]
    hs = [ymaxs[i] - ymins[i] for i in range(len(names))]
    annos = [dict(category=names[i].text, bbox=[xmins[i], ymins[i], ws[i], hs[i]]) for i in range(len(names))]
    return annos
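names gets one extra filtering pass because other nodes in the XML can also have a child whose tag is name. Without the filter, names and the coordinate lists would no longer correspond one to one, which is why that line was added.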
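The same filtering can probably be pushed into the XPath expression itself: //object/name selects only name nodes whose parent is object, so no getparent check is needed. The sketch below illustrates this on a made-up sample; SAMPLE and the owner node are assumptions used only to show a stray name, not content from the real file.

```python
from lxml import etree

# Hypothetical sample: <owner> is an assumed example of a non-object node
# that also has a <name> child.
SAMPLE = b"""<annotation>
  <owner><name>someone</name></owner>
  <object>
    <name>cat</name>
    <bndbox><xmin>1</xmin><ymin>2</ymin><xmax>11</xmax><ymax>22</ymax></bndbox>
  </object>
  <object>
    <name>dog</name>
    <bndbox><xmin>5</xmin><ymin>6</ymin><xmax>30</xmax><ymax>40</ymax></bndbox>
  </object>
</annotation>"""

root = etree.fromstring(SAMPLE)

# '//name' matches every name node, including the one under <owner>.
all_names = [e.text for e in root.xpath('//name')]
# '//object/name' matches only name nodes that are direct children of object.
object_names = [e.text for e in root.xpath('//object/name')]
# The same path qualification keeps the coordinates tied to object nodes.
xmins = [int(e.text) for e in root.xpath('//object/bndbox/xmin')]
```

Note that lining the lists up by index still relies on every object containing exactly one name and one full set of coordinates; if that cannot be guaranteed, looping per object is the safer choice.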