准备工作:将所需要划分的数据集xml都放到一个文件下。
ps:只适用于yolo目标检测,输出是这样的形式:图片文件名,种类,x,y,w,h !!
代码如下
import random
import os
import xml.etree.ElementTree as ET
sort_names = {
'cat': 0
}
#上面sort_names是用来放你的标签名字的,比如说我这里检测猫 我就自定义'cat':0 如果需要添加一共dog 就写成 'dog':1 这里的标签名字一定要跟你xml里面一样!!!
xml_dir = r'B:\software\deeplearning\pytorch_project\yolov3_steel\hyj_test\annotations'
#xml_dir就是标签xml放置的文件夹地址
xml_names = os.listdir(xml_dir) #读取目录中所有的xml文件名字
xml_numbers = len(xml_names)
train_numbers = int(xml_numbers * 0.8) #这里*0.8就是1000个xml 800个训练 200个验证 可以自己修改
new_xml_names = xml_names #创建一个新的list来存放,后面有用到
random_list_numbers = random.sample(xml_names, train_numbers) #这里的train_numbers不能比xml_names大 不然会报错!!!
for i in random_list_numbers: #这里的random_list_numbers就是随机生成的train里面所有的文件名字
new_xml_names.remove(i) #使用remove将1000个xml中 随机生产的800个train的xml给去除了,剩下的就是val的xml
train_xml = random_list_numbers #train xml的名字800个 随机生成的
val_xml = new_xml_names #val xml200个
#下面这个是将train里面800个xml生成txt文件 文件名为train_xml 里面存放了800条信息对应的种类 中心坐标和检测框长宽
with open('train_xml', 'a') as f:
for i in train_xml:
xml_position = os.path.join(xml_dir, i)
xml_content = ET.parse(xml_position)
root = xml_content.getroot()
picture_name = root.find('filename')
sort_name = root.findall('object/name')
picture_boxes = root.findall('object/bndbox')
hyj = []
hyj.append(picture_name.text)
for sort, boxes in zip(sort_name, picture_boxes):
_sort = sort.text
number_sort = sort_names[_sort]
w = int(boxes[2].text) - int(boxes[0].text)
h = int(boxes[3].text) - int(boxes[1].text)
cx = int(w / 2) + int(boxes[0].text)
cy = int(h / 2) + int(boxes[1].text)
hyj.append(number_sort)
hyj.append(cx)
hyj.append(cy)
hyj.append(w)
hyj.append(h)
hyj2=''
for i in hyj:
hyj2=hyj2+' '+str(i)
f.write(hyj2+'\n')
f.close()
#下面这个是生成了200个val的txt文件同上
with open('val_xml', 'a') as f:
for i in val_xml:
xml_position = os.path.join(xml_dir, i)
xml_content = ET.parse(xml_position)
root = xml_content.getroot()
picture_name = root.find('filename')
sort_name = root.findall('object/name')
picture_boxes = root.findall('object/bndbox')
hyj = []
hyj.append(picture_name.text)
for sort, boxes in zip(sort_name, picture_boxes):
_sort = sort.text
number_sort = sort_names[_sort]
w = int(boxes[2].text) - int(boxes[0].text)
h = int(boxes[3].text) - int(boxes[1].text)
cx = int(w / 2) + int(boxes[0].text)
cy = int(h / 2) + int(boxes[1].text)
hyj.append(number_sort)
hyj.append(cx)
hyj.append(cy)
hyj.append(w)
hyj.append(h)
hyj2=''
for i in hyj:
hyj2=hyj2+' '+str(i)
f.write(hyj2+'\n')
f.close()