前言
在用目标检测做分类的时候，模型并不会判断物体的属性，只是根据学到的特征判断它可能属于哪个标签，并给出 bbox。但本研究的对象主要是人，所以桌子等其他干扰必须去除。于是想到：是否可以对所有样本数据新增一个 label="person"？也就是在原来所有的 xml 文件中，把每个 object 复制一份，并把新增 object 的 label 改为 "person"。
代码
import os
import re

import numpy as np
def main(xml_read_data, xml_write_data):
    """Process a single annotation file.

    Parses ``xml_read_data`` with ``read_xml`` and passes the resulting lines
    and object boundaries to ``add_person_object``, which writes the output
    to ``xml_write_data``.
    """
    file_lines, starts, finishes, n_objects = read_xml(xml_read_data)
    add_person_object(xml_write_data, file_lines, starts, finishes, n_objects)
def read_xml(xml):
    """Read an annotation file and locate its ``<object>`` blocks.

    Assumes each ``<object>`` / ``</object>`` tag sits on its own line.

    Args:
        xml: Path of the annotation file to read.

    Returns:
        A tuple ``(lines, num_start, num_finish, num_len)``:
        ``lines`` holds every line of the file (newlines kept),
        ``num_start`` / ``num_finish`` hold the 0-based indices of the
        ``<object>`` / ``</object>`` lines, and ``num_len`` is the number of
        ``<object>`` lines found.
    """
    lines = []
    num_start = []
    num_finish = []
    with open(xml) as f:
        for i, line in enumerate(f):
            # Compare the stripped line so tags are found whatever the
            # indentation (tab or spaces) or trailing whitespace.
            tag = line.strip()
            if tag == "<object>":
                num_start.append(i)
            elif tag == "</object>":
                num_finish.append(i)
            lines.append(line)
    return lines, num_start, num_finish, len(num_start)
# Text content of a <name> element. Only the first match inside each copied
# <object> is replaced. This assumes, as in Pascal VOC-style files, that the
# first <name> is the object's own label; any later <name> (e.g. inside a
# nested <part>) is left alone.
_NAME_TEXT = re.compile(r"(?<=<name>)[^<]*(?=</name>)")


def add_person_object(xml_write_data, lines, num_start, num_finish, num_len,
                      multi_object_lines=None, final_lines='', new_label='person'):
    """Write ``lines`` plus a relabelled copy of every ``<object>`` block.

    Each object ``lines[num_start[i]:num_finish[i] + 1]`` is duplicated, and
    the text of the copy's first ``<name>`` element is replaced by
    ``new_label``. The copies are inserted before the last non-blank line
    (normally the closing root tag). The output file is always written, so
    an annotation without objects is copied unchanged.

    Args:
        xml_write_data: Output path; its directory must already exist.
        lines: File lines as returned by ``read_xml`` (newlines kept).
        num_start: Line indices of the ``<object>`` tags.
        num_finish: Line indices of the matching ``</object>`` tags.
        num_len: Number of objects to duplicate.
        multi_object_lines: Unused; accepted only so existing call sites keep
            working.
        final_lines: Unused; accepted only so existing call sites keep working.
        new_label: Label given to every duplicated object.

    Raises:
        ValueError: If there are fewer ``</object>`` than ``<object>`` lines.
    """
    if len(num_finish) < num_len:
        raise ValueError(
            "{}: found {} <object> tags but only {} </object> tags".format(
                xml_write_data, num_len, len(num_finish)))

    copies = []
    for start, finish in zip(num_start[:num_len], num_finish):
        object_text = ''.join(lines[start:finish + 1])
        # A callable replacement keeps new_label literal (no backslash escapes).
        copies.append(_NAME_TEXT.sub(lambda _m: new_label, object_text, count=1))

    # Insert before the last non-blank line so trailing blank lines cannot
    # push the copies outside the root element.
    insert_at = len(lines)
    for idx in range(len(lines) - 1, -1, -1):
        if lines[idx].strip():
            insert_at = idx
            break

    output = lines[:insert_at] + copies + lines[insert_at:]
    with open(xml_write_data, 'w') as f:
        f.write(''.join(output))
if __name__ == '__main__':
    # Dataset root; the input and output annotation folders live under it.
    root_dir = ''
    xml_data_before = os.path.join(root_dir, 'Annotations')
    xml_data_finished = os.path.join(root_dir, 'added_annotations')
    # The output folder must exist before files are opened for writing in it.
    os.makedirs(xml_data_finished, exist_ok=True)
    # List the directory once so the progress total matches the files
    # processed; skip anything that is not an .xml annotation.
    files = sorted(name for name in os.listdir(xml_data_before)
                   if name.lower().endswith('.xml'))
    count = len(files)
    for i, file in enumerate(files, start=1):
        xml_read_data = os.path.join(xml_data_before, file)
        xml_write_data = os.path.join(xml_data_finished, file)
        print("the process is {}/{}".format(i, count))
        main(xml_read_data, xml_write_data)