代码如下:
import os
import random
import argparse
parser = argparse.ArgumentParser()
# 1.xml文件的文件夹路径,根据自己的数据进行修改。 xml一般存放在Annotations下
parser.add_argument('--xml_path', default='D:/数据集/dataset/SeaShips(7000)/Annotations', type=str, help='input xml label path')
# 2.保存 数据集划分生成的txt文件 的文件夹路径。
parser.add_argument('--txt_path', default='D:/数据集/dataset/SeaShips(7000)/ImageSets/Main', type=str, help='output txt label path')
opt = parser.parse_args()
trainval_percent = 1 # 3.训练+验证集一共所占的比例,剩下的就是测试集了。
train_percent = 0.9 # 4.训练集在训练集和验证集总集合中占的比例
xmlfilepath = opt.xml_path
txtsavepath = opt.txt_path
total_xml = os.listdir(xmlfilepath)
if not os.path.exists(txtsavepath):
os.makedirs(txtsavepath)
num = len(total_xml)
list_index = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list_index, tv)
train = random.sample(trainval, tr)
file_trainval = open(txtsavepath + '/trainval.txt', 'w')
file_test = open(txtsavepath + '/test.txt', 'w')
file_train = open(txtsavepath + '/train.txt', 'w')
file_val = open(txtsavepath + '/val.txt', 'w')
for i in list_index:
name = total_xml[i][:-4] + '\n' # 5.-4可以用来去掉文件后缀名
if i in trainval:
file_trainval.write(name)
if i in train:
file_train.write(name)
else:
file_val.write(name)
else:
file_test.write(name)
file_trainval.close()
file_train.close()
file_val.close()
file_test.close()