1. rate_train_val_test.py：按比例将VOC格式数据集划分为train、val和test三部分；
import os
import random
# Random seed, so repeated runs produce the same split
seed = 0
# Fraction of the whole dataset assigned to trainval; the remainder goes to test
trainval_percent = 1
# Fraction of trainval assigned to train; the remainder goes to val
train_percent = 0.8
# Dataset root directory
vocPath = r'D:\Desktop\dataset\SeaShips(7000)'
# Directory holding the VOC xml annotations
xmlPath = os.path.join(vocPath, 'Annotations')
# Output directory for the VOC trainval/train/val/test txt lists
saveBasePath = os.path.join(vocPath, 'ImageSets', 'Main')


def split_names(names, trainval_ratio, train_ratio):
    """Randomly split sample names into VOC trainval/train/val/test lists.

    ``int(len(names) * trainval_ratio)`` names go to trainval and the rest to
    test; ``int(len(trainval) * train_ratio)`` of the trainval names go to
    train and the rest to val. Randomness comes from the global ``random``
    module, so seed it beforehand for a reproducible split. Every returned
    list keeps the relative order of ``names``.

    Args:
        names: Sequence of sample names (file stems).
        trainval_ratio: Fraction of all names assigned to trainval.
        train_ratio: Fraction of trainval assigned to train.

    Returns:
        Tuple ``(trainval, train, val, test)`` of name lists.
    """
    num = len(names)
    tv = int(num * trainval_ratio)
    tr = int(tv * train_ratio)
    # random.sample needs sequences (sets are rejected on Python 3.11+), so
    # sample from lists first and only then convert to sets for O(1) lookups.
    trainval_idx = random.sample(range(num), tv)
    train_idx = set(random.sample(trainval_idx, tr))
    trainval_idx = set(trainval_idx)

    trainval, train, val, test = [], [], [], []
    for i, name in enumerate(names):
        if i in trainval_idx:
            trainval.append(name)
            (train if i in train_idx else val).append(name)
        else:
            test.append(name)
    return trainval, train, val, test


def main():
    """Split the xml annotations under xmlPath and write the VOC list files."""
    random.seed(seed)
    # Sort so the split depends only on the seed, not on os.listdir order.
    total_xml = sorted(f for f in os.listdir(xmlPath) if f.endswith('.xml'))
    names = [f[:-4] for f in total_xml]  # strip the ".xml" extension
    trainval, train, val, test = split_names(names, trainval_percent, train_percent)
    print("train and val size", len(trainval))
    print("train size", len(train))

    os.makedirs(saveBasePath, exist_ok=True)
    outputs = (
        ('trainval.txt', trainval),
        ('test.txt', test),
        ('train.txt', train),
        ('val.txt', val),
    )
    for file_name, subset in outputs:
        # One name per line, each terminated by a newline.
        with open(os.path.join(saveBasePath, file_name), 'w') as f:
            f.writelines(name + '\n' for name in subset)


if __name__ == '__main__':
    main()
2. voc_xml_to_txt.py：将VOC格式数据集Annotations目录中的xml标注文件转换为YOLO格式的txt标注文件（提示：需将class_names修改为自己数据集的类别）；
import os.path
import xml.etree.ElementTree as ET
# Class names; an object's index in this list becomes its YOLO class id.
# Edit this list to match your own dataset.
class_names = ["ore carrier", "general cargo ship", "bulk cargo carrier", "container ship", "fishing boat", "passenger ship"]
# VOC dataset root
vocPath = r'D:\Desktop\dataset\SeaShips(7000)'
# Directory holding the VOC xml annotations
xmlPath = os.path.join(vocPath, 'Annotations')
# Output directory for the converted txt label files
txtPath = os.path.join(vocPath, 'txts')


def voc_to_yolo_lines(root, classes):
    """Convert one parsed VOC annotation into YOLO label text.

    Each ``<object>`` becomes one line ``"class_id cx cy w h"``. The box
    centre and size are divided by the image width/height read from
    ``<size>``, so they are normalised to the image dimensions. Lines are
    newline-separated with no trailing newline; an annotation without
    objects yields ``""``.

    Args:
        root: Root element of a VOC annotation document.
        classes: Class names; the index of an object's name is its class id.

    Returns:
        The label text for this annotation.

    Raises:
        ValueError: If an object's name is not in ``classes``.
    """
    size = root.find('size')
    # float() also accepts values written like "123.0", which int() rejects;
    # for integral values the computed ratios are identical.
    w = float(size.find('width').text)
    h = float(size.find('height').text)
    lines = []
    for obj in root.iter('object'):
        cls_name = obj.find('name').text.strip()
        if cls_name not in classes:
            raise ValueError("unknown class {!r}: add it to class_names".format(cls_name))
        box = obj.find('bndbox')
        x1 = float(box.find('xmin').text)
        x2 = float(box.find('xmax').text)
        y1 = float(box.find('ymin').text)
        y2 = float(box.find('ymax').text)
        values = (
            classes.index(cls_name),
            (x1 + x2) / 2 / w,  # box centre x
            (y1 + y2) / 2 / h,  # box centre y
            (x2 - x1) / w,      # box width
            (y2 - y1) / h,      # box height
        )
        lines.append(" ".join(str(v) for v in values))
    return "\n".join(lines)


def main():
    """Convert every .xml file directly under xmlPath into a txt file in txtPath."""
    os.makedirs(txtPath, exist_ok=True)
    xml_names = sorted(f for f in os.listdir(xmlPath) if f.endswith('.xml'))
    print(len(xml_names))
    for xml_name in xml_names:
        # Decoded as gb18030, silently dropping undecodable bytes.
        with open(os.path.join(xmlPath, xml_name), encoding='gb18030', errors='ignore') as xml_file:
            root = ET.parse(xml_file).getroot()
        content = voc_to_yolo_lines(root, class_names)
        print(content)
        with open(os.path.join(txtPath, xml_name[:-4] + '.txt'), 'w') as f_txt:
            f_txt.write(content)


if __name__ == '__main__':
    main()
3. voc_to_yolo.py：将VOC格式数据集转换为YOLO格式（按第1步的划分复制图片及第2步生成的txt标注）；
import os
import shutil
from tqdm import tqdm
# VOC dataset root
vocPath = r'D:\Desktop\dataset\SeaShips(7000)'
# Directory containing the VOC train/val/test txt lists
saveBasePath = os.path.join(vocPath, 'ImageSets', 'Main')
# Output root of the converted YOLO dataset
yoloPath = r'D:\Desktop\dataset\SeaShips(7000)\SeaShips-yolo'
# VOC image directory
IMG_PATH = os.path.join(vocPath, "JPEGImages")
# Directory of txt labels converted from the VOC xml files
TXT_PATH = os.path.join(vocPath, "txts")
# Image and label directories of the YOLO dataset
TO_IMG_PATH = os.path.join(yoloPath, 'images')
TO_TXT_PATH = os.path.join(yoloPath, 'labels')
# Each VOC split list and the YOLO sub-directory it is copied into (paired by position)
data_split = ['train.txt', 'val.txt', 'test.txt']
to_split = ['train2017', 'val2017', 'test2017']


def read_split(split_path):
    """Return the whitespace-stripped, non-empty sample names listed in split_path."""
    with open(split_path, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def copy_if_exists(src, dst):
    """Copy the file src to dst.

    Returns:
        True when the copy was made; False (after printing an error) when
        src is not an existing regular file.
    """
    if not os.path.isfile(src):
        print("error file: {}".format(src))
        return False
    shutil.copyfile(src, dst)
    return True


def main():
    """Copy each split's images and txt labels into the YOLO directory layout."""
    for split, to_name in zip(data_split, to_split):
        to_imgs_path = os.path.join(TO_IMG_PATH, to_name)
        to_txts_path = os.path.join(TO_TXT_PATH, to_name)
        os.makedirs(to_imgs_path, exist_ok=True)
        os.makedirs(to_txts_path, exist_ok=True)
        names = read_split(os.path.join(saveBasePath, split))
        for stem in tqdm(names, desc="{} is copying".format(to_name)):
            # Copy the image
            copy_if_exists(os.path.join(IMG_PATH, stem + '.jpg'),
                           os.path.join(to_imgs_path, stem + '.jpg'))
            # Copy the txt label file
            copy_if_exists(os.path.join(TXT_PATH, stem + '.txt'),
                           os.path.join(to_txts_path, stem + '.txt'))


if __name__ == '__main__':
    main()
最终文件结构如下，其中第3步脚本中yoloPath所指的目录即为最终的YOLO数据集。