# Split a YOLO-format dataset into training (70%), validation (10%) and test (20%) subsets.
import os
import shutil
import random
# Fix the global RNG seed so that random.shuffle in split_data produces the
# same permutation of a given input list on every run.
random.seed(0)
def split_data(file_path, txt_path, new_file_path, train_rate, val_rate, test_rate,
               image_exts=('.png', '.jpg', '.jpeg')):
    """Copy a flat YOLO dataset into train/val/test sub-folders.

    Images are read from *file_path*. The label for ``<stem>.<ext>`` is
    expected at ``<txt_path>/<stem>.txt``. The output layout is::

        new_file_path/{train,val,test}/images/
        new_file_path/{train,val,test}/labels/

    Images are shuffled with the module-level ``random`` generator, so
    seeding it beforehand makes the split reproducible.

    Args:
        file_path: Directory containing the source images.
        txt_path: Directory containing the ``.txt`` label files.
        new_file_path: Root of the split dataset (created if missing).
        train_rate: Fraction of images assigned to the training set.
        val_rate: Fraction of images assigned to the validation set.
        test_rate: Fraction assigned to the test set. The test set receives
            every image not taken by train/val, so the three rates must sum to 1.
        image_exts: File extensions (case-insensitive) treated as images.

    Raises:
        ValueError: If any rate is negative or the rates do not sum to 1.

    Note:
        Existing files under *new_file_path* are not removed. Re-running into
        a non-empty directory with a different split can leave stale copies
        of an image in another subset.
    """
    rates = (train_rate, val_rate, test_rate)
    if any(rate < 0 for rate in rates):
        raise ValueError(f"rates must be non-negative, got {rates}")
    if abs(sum(rates) - 1.0) > 1e-6:
        raise ValueError(
            f"train_rate + val_rate + test_rate must equal 1, got {sum(rates)}")

    exts = tuple(ext.lower() for ext in image_exts)
    # os.listdir order is filesystem-dependent. Sorting first makes the seeded
    # shuffle (and therefore the split) reproducible across machines.
    images = sorted(name for name in os.listdir(file_path)
                    if name.lower().endswith(exts))
    random.shuffle(images)
    total = len(images)

    # round() rather than int(): int((0.7 + 0.1) * 10) == 7 because of
    # floating-point error, which would leave the validation set empty.
    train_end = round(train_rate * total)
    val_end = round((train_rate + val_rate) * total)
    splits = {
        'train': images[:train_end],
        'val': images[train_end:val_end],
        'test': images[val_end:],
    }

    for dataset_type, subset in splits.items():
        images_dir = os.path.join(new_file_path, dataset_type, 'images')
        labels_dir = os.path.join(new_file_path, dataset_type, 'labels')
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(labels_dir, exist_ok=True)
        for image in subset:
            shutil.copy(os.path.join(file_path, image),
                        os.path.join(images_dir, image))
            label_name = os.path.splitext(image)[0] + '.txt'
            old_label_path = os.path.join(txt_path, label_name)
            if os.path.exists(old_label_path):
                shutil.copy(old_label_path, os.path.join(labels_dir, label_name))
            else:
                # The image itself has already been copied; only the label is skipped.
                print(f"Label file for {image} does not exist, skipping.")
if __name__ == '__main__':
    # Source images and labels live side by side under dataset/. The split
    # copy is written to dataset/divided/{train,val,test}/{images,labels}.
    split_data(
        "dataset/images",
        "dataset/labels",
        "dataset/divided",
        train_rate=0.7,
        val_rate=0.1,
        test_rate=0.2,
    )