# Reading XML tags
from xml.etree import ElementTree
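# copyfile (from shutil) is used to copy the source images.
# Path and file configuration
# Configure these according to your own paths; they are mainly used to store the processed data.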
import os
from shutil import copyfile
from xml.etree import ElementTree
# Path and file configuration. Adjust to your own layout; the train/val
# directories and list files are where the processed data is stored.
train_path = './/datasets//data//train'
val_path = './/datasets//data//value'
train_txt = './/datasets//data//train.txt'
sourse = './/datasets//data//VOCdevkit//VOC2012//JPEGImages'
path_prefix = './/datasets//data//VOCdevkit//VOC2012//ImageSets//Main//person_train.txt'
val_txt = './/datasets//data//val.txt'
annotations_path = './datasets//data//VOCdevkit//VOC2012//Annotations'


def _read_image_ids(list_path):
    """Return the image ids found in the first column of an image-set list file.

    Blank lines are skipped. Taking the first whitespace-separated token is
    used instead of slicing a fixed number of trailing characters, so the
    result does not depend on the width of the flag column or on a trailing
    newline.
    """
    with open(list_path, 'r') as list_file:
        return [line.split()[0] for line in list_file if line.strip()]


def _person_boxes(xml_path):
    """Return one normalized ``[cx, cy, w, h]`` box per 'person' object.

    Coordinates are divided by the image width/height read from the ``size``
    element of the annotation, then converted from corner form to
    centre/size form.
    """
    root = ElementTree.parse(xml_path).getroot()
    size_tree = root.find('size')
    width = float(size_tree.find('width').text)
    height = float(size_tree.find('height').text)

    boxes = []
    for object_tree in root.findall('object'):
        # Only the person class is wanted.
        if object_tree.find('name').text != 'person':
            continue
        # find() looks at direct children only; iter() would also pick up
        # the bndbox elements nested inside <part> children of the object.
        bndbox = object_tree.find('bndbox')
        if bndbox is None:
            continue
        xmin = float(bndbox.find('xmin').text) / width
        ymin = float(bndbox.find('ymin').text) / height
        xmax = float(bndbox.find('xmax').text) / width
        ymax = float(bndbox.find('ymax').text) / height
        boxes.append([(xmin + xmax) / 2., (ymin + ymax) / 2., xmax - xmin, ymax - ymin])
    return boxes


def _export_image(image_id, boxes, images_dir, out_dir, list_file):
    """Copy one image into out_dir, write its label file and list entry."""
    image_out = out_dir + '//' + image_id + '.jpg'
    copyfile(images_dir + '//' + image_id + '.jpg', image_out)
    # One label file per image, one line per person: "1 cx cy w h".
    with open(out_dir + '//' + image_id + '.txt', 'w') as label_file:
        for box in boxes:
            label_file.write('1' + ' ' + ' '.join([str(v) for v in box]) + '\n')
    list_file.write(image_out + '\n')


def build_person_dataset(list_path=path_prefix, annotations_dir=annotations_path,
                         images_dir=sourse, train_dir=train_path, val_dir=val_path,
                         train_list=train_txt, val_list=val_txt,
                         train_count=1000, val_count=300):
    """Extract person images and labels into a training and a validation set.

    Image ids are read from ``list_path``. For each id the matching XML
    annotation is parsed; images containing at least one 'person' object are
    copied, together with a label file, into ``train_dir`` until
    ``train_count`` images have been written, then into ``val_dir`` for the
    next ``val_count`` images. Processing stops once both quotas are filled.
    The path of every copied image is appended to ``train_list`` or
    ``val_list`` respectively.

    The counter that splits the two sets advances once per image (not once
    per bounding box), so an image never ends up in both sets and all of its
    person boxes land in a single label file.

    Returns:
        The number of images written (training plus validation).

    Raises:
        OSError: if a list, annotation or image file cannot be read/written.
        xml.etree.ElementTree.ParseError: if an annotation is malformed.
    """
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    total = train_count + val_count
    written = 0
    with open(train_list, 'w') as train_file, open(val_list, 'w') as val_file:
        for image_id in _read_image_ids(list_path):
            if written >= total:
                # Both quotas are full: stop before parsing any more XML.
                break
            # Locate the annotation by image id and read its person boxes.
            boxes = _person_boxes(annotations_dir + '//' + image_id + '.xml')
            if not boxes:
                continue
            if written < train_count:
                _export_image(image_id, boxes, images_dir, train_dir, train_file)
            else:
                _export_image(image_id, boxes, images_dir, val_dir, val_file)
            written += 1
            print(written)
    return written


if __name__ == '__main__':
    build_person_dataset()