源代码说明
包含对应的python代码,测试用到的图像及其标注文件。
https://download.csdn.net/download/u011775793/88631532
什么是训练集,验证集和测试集?
在机器学习中,训练集、验证集和测试集是用于评估模型性能的三个不同数据集。
训练集是指用于训练模型的数据集,通常占总数据集的大部分。通过使用训练集来训练模型,可以使模型学习到数据中的模式和规律,从而提高模型的预测能力。
验证集是指用于评估模型性能的数据集,通常占总数据集的一部分。通过将训练好的模型应用于验证集上,可以评估模型在新数据上的性能,并进行调整和优化。
测试集是指用于最终评估模型性能的数据集,通常占总数据集的一小部分。通过将训练好的模型应用于测试集上,可以得到模型在未知数据上的最终性能。
划分训练集、验证集和测试集的方法有很多种,其中一种常见的方法是将数据集随机分为三部分,分别作为训练集、验证集和测试集。例如,可以将70%的数据作为训练集,15%的数据作为验证集,剩下的15%的数据作为测试集。这种方法称为分层抽样(stratified sampling),可以保证训练集、验证集和测试集中的类别分布与原始数据集相同,从而更好地评估模型的性能。
划分比例的选择可能会影响到模型的性能和泛化能力。因此,在选择划分比例时需要综合考虑多个因素,例如数据集的大小、特征数量、类别分布等。
如何生成YOLOv8训练用的数据集(划分成训练集,验证集和测试集)?
开发思路
输入三个数据集的比例(train_ratio,val_ratio,test_ratio),然后读取根据AnyLabeling标注数据生成的原始YOLO格式的数据集(可以参考另一篇文章:[YOLOv8] 缺陷检测之AnyLabeling标注格式转换成YOLO格式),并对数据集的所有图像进行随机打乱,按照前面输入的比例,把图像和标签文件复制到新创建的YOLOv8数据集的train,val,test子目录下,最后把使用到的classes.txt类别文件,也复制过去,同时生成数据集描述文件data.yaml。
数据集拆分代码
import os
import random
import shutil
import time
import yaml
from wepy import get_logger, init_logger, GLOBAL_ENCODING
class YOLOTrainDataSetGenerator:
def __init__(self, origin_dataset_dir, train_dataset_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
clear_train_dir=False):
# 设置随机数种子
random.seed(1233)
self.origin_dataset_dir = origin_dataset_dir
self.train_dataset_dir = train_dataset_dir
self.train_ratio = train_ratio
self.val_ratio = val_ratio
self.test_ratio = test_ratio
self.clear_train_dir = clear_train_dir
assert self.train_ratio > 0.5, 'train_ratio must larger than 0.5'
assert self.val_ratio > 0.01, 'train_ratio must larger than 0.01'
assert self.test_ratio > 0.01, 'test_ratio must larger than 0.01'
total_ratio = round(self.train_ratio + self.val_ratio + self.test_ratio)
assert total_ratio == 1.0, 'train_ratio + val_ratio + test_ratio must equal 1.0'
def generate(self):
time_start = time.time()
get_logger().info(f'start to split origin data set. \n'
f'origin_dataset_dir:{self.origin_dataset_dir},\n'
f'train_dataset_dir:{self.train_dataset_dir},\n'
f'train_ratio:{self.train_ratio},val_ratio:{self.val_ratio}, test_ratio:{self.test_ratio}')
# 原始数据集的图像目录,标签目录,和类别文件路径
origin_image_dir = os.path.join(self.origin_dataset_dir, 'images')
origin_label_dir = os.path.join(self.origin_dataset_dir, 'labels')
origin_classes_file = os.path.join(self.origin_dataset_dir, 'classes.txt')
if not os.path.exists(origin_classes_file):
get_logger().error(f'classes file is not found. classes_file:{origin_classes_file}')
return
else:
origin_classes = {}
with open(origin_classes_file, mode='r', encoding=GLOBAL_ENCODING) as f:
for cls_id, cls_name in enumerate(f.readlines()):
cls_name = cls_name.strip()
if cls_name != '':
origin_classes[cls_id] = cls_name
# 获取所有原始图像文件名(包括后缀名)
origin_image_filenames = os.listdir(origin_image_dir)
# 随机打乱文件名列表
random.shuffle(origin_image_filenames)
# 计算训练集、验证集和测试集的数量
total_count = len(origin_image_filenames)
train_count = int(total_count * self.train_ratio)
val_count = int(total_count * self.val_ratio)
test_count = total_count - train_count - val_count
# 定义训练集文件夹路径
if self.clear_train_dir and os.path.exists(self.train_dataset_dir):
shutil.rmtree(self.train_dataset_dir, ignore_errors=True)
train_dir = os.path.join(self.train_dataset_dir, 'train')
val_dir = os.path.join(self.train_dataset_dir, 'val')
test_dir = os.path.join(self.train_dataset_dir, 'test')
train_image_dir = os.path.join(train_dir, 'images')
train_label_dir = os.path.join(train_dir, 'labels')
val_image_dir = os.path.join(val_dir, 'images')
val_label_dir = os.path.join(val_dir, 'labels')
test_image_dir = os.path.join(test_dir, 'images')
test_label_dir = os.path.join(test_dir, 'labels')
# 创建训练集输出文件夹
os.makedirs(train_image_dir, exist_ok=True)
os.makedirs(train_label_dir, exist_ok=True)
os.makedirs(val_image_dir, exist_ok=True)
os.makedirs(val_label_dir, exist_ok=True)
os.makedirs(test_image_dir, exist_ok=True)
os.makedirs(test_label_dir, exist_ok=True)
# 将图像和标签文件按设定的ratio划分到训练集,验证集,测试集中
for i, filename in enumerate(origin_image_filenames):
if i < train_count:
output_image_dir = train_image_dir
output_label_dir = train_label_dir
elif i < train_count + val_count:
output_image_dir = val_image_dir
output_label_dir = val_label_dir
else:
output_image_dir = test_image_dir
output_label_dir = test_label_dir
src_img_name_no_ext = os.path.splitext(filename)[0]
src_image_path = os.path.join(origin_image_dir, filename)
src_label_path = os.path.join(origin_label_dir, src_img_name_no_ext + '.txt')
if os.path.exists(src_label_path):
# 复制图像文件
dst_image_path = os.path.join(output_image_dir, filename)
shutil.copy(src_image_path, dst_image_path)
# 复制标签文件
src_label_path = os.path.join(origin_label_dir, src_img_name_no_ext + '.txt')
dst_label_path = os.path.join(output_label_dir, src_img_name_no_ext + '.txt')
shutil.copy(src_label_path, dst_label_path)
else:
get_logger().error(f'no label file found for image file. img_file:{src_image_path}')
train_dir = os.path.normpath(train_dir)
val_dir = os.path.normpath(val_dir)
test_dir = os.path.normpath(test_dir)
get_logger().info(f'generate train, val, test data set. \n'
f'train_count:{train_count}, train_dir:{train_dir}\n'
f'val_count:{val_count}, val_dir:{val_dir}\n'
f'test_count:{test_count}, test_dir:{test_dir}')
# 生成描述训练集的yaml文件
data_dict = {
'train': train_dir,
'val': val_dir,
'test': test_dir,
'nc': len(origin_classes),
'names': origin_classes
}
yaml_file_path = os.path.normpath(os.path.join(self.train_dataset_dir, 'data.yaml'))
with open(yaml_file_path, mode='w', encoding=GLOBAL_ENCODING) as f:
yaml.safe_dump(data_dict, f, default_flow_style=False, allow_unicode=True, encoding=GLOBAL_ENCODING)
get_logger().info(f'generate the `data.yaml`. data:{data_dict}, yaml_file_path:{yaml_file_path}')
get_logger().info('end to ')
if __name__ == '__main__':
init_logger('logs/split_data.log')
g_origin_dataset_dir = 'D:/YOLOv8Train/v8_origin_datasets/mktk_dataset'
g_train_dataset_dir = 'D:/YOLOv8Train/v8_train_datasets/mktk_dataset'
g_train_ratio = 0.7
g_val_ratio = 0.15
g_test_ratio = 0.15
yolo_generator = YOLOTrainDataSetGenerator(g_origin_dataset_dir, g_train_dataset_dir, g_train_ratio, g_val_ratio,
g_test_ratio, True)
yolo_generator.generate()
运行效果
原始YOLO数据集,images包含被标注的图像,labels包含对应图像的标注文件,classes.txt包含标注的类别:
生成用于YOLOv8训练用的数据集,如下图所示: