import os
import shutil
import random
def split_dataset(img_dir, labels_dir, output_dir, train_ratio=0.80, val_ratio=0.20, image_ext='.jpg', label_ext='.txt',
                  seed=None):
    """Copy an image/label dataset into train and val sub-folders.

    Creates ``<output_dir>/imgs/{train,val}`` and ``<output_dir>/labels/{train,val}``
    and copies each image together with its label file (same base name, with
    ``label_ext``) into the matching pair of folders. Source files are left in place.

    Images that have no label file are reported and excluded *before* the split,
    so the ratios and the printed counts describe the files that are actually copied.

    Args:
        img_dir: Directory containing the source images.
        labels_dir: Directory containing the source label files.
        output_dir: Root directory under which the split is written.
        train_ratio: Fraction of usable images placed in the train subset.
        val_ratio: Fraction placed in the val subset; it receives the remainder
            after the train subset has been cut.
        image_ext: Image file extension, or a tuple of extensions. Matching is
            case-insensitive, so ``'.jpg'`` also selects ``.JPG`` files.
        label_ext: Extension of the label files.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` state is used.

    Raises:
        AssertionError: If ``train_ratio + val_ratio`` is not 1 (within 1e-4).
    """
    # Raised explicitly (rather than via ``assert``) so the check survives ``python -O``;
    # the exception type is kept as AssertionError for existing callers.
    if not abs(train_ratio + val_ratio - 1) < 0.0001:
        raise AssertionError("Ratios must sum up to 1")

    # Create the output directories.
    imgs_train_dir = os.path.join(output_dir, 'imgs', 'train')
    imgs_val_dir = os.path.join(output_dir, 'imgs', 'val')
    labels_train_dir = os.path.join(output_dir, 'labels', 'train')
    labels_val_dir = os.path.join(output_dir, 'labels', 'val')
    for d in (imgs_train_dir, imgs_val_dir, labels_train_dir, labels_val_dir):
        os.makedirs(d, exist_ok=True)

    # Normalise the extension(s) so matching ignores case.
    extensions = (image_ext,) if isinstance(image_ext, str) else tuple(image_ext)
    extensions = tuple(ext.lower() for ext in extensions)

    # Collect (image, label) pairs in sorted order; drop unlabeled images up front
    # so they cannot distort the split sizes.
    pairs = []
    for img_name in sorted(os.listdir(img_dir)):
        if not img_name.lower().endswith(extensions):
            continue
        label_name = os.path.splitext(img_name)[0] + label_ext
        if os.path.exists(os.path.join(labels_dir, label_name)):
            pairs.append((img_name, label_name))
        else:
            print(f"Label file for {img_name} not found. Skipping this file.")

    # Shuffle; a private generator keeps a seeded run from touching global state.
    if seed is None:
        random.shuffle(pairs)
    else:
        random.Random(seed).shuffle(pairs)

    # The small epsilon guards against float products that land just below an
    # integer (e.g. 90 * 0.7 == 62.99999999999999) being truncated one too low.
    train_end = int(len(pairs) * train_ratio + 1e-9)
    train_pairs = pairs[:train_end]
    val_pairs = pairs[train_end:]

    def copy_files(subset, dest_img_dir, dest_labels_dir):
        """Copy every (image, label) pair of ``subset`` into the destination folders."""
        for img_name, label_name in subset:
            shutil.copy(os.path.join(img_dir, img_name), dest_img_dir)
            shutil.copy(os.path.join(labels_dir, label_name), dest_labels_dir)

    # Copy the files into the train and val directories.
    copy_files(train_pairs, imgs_train_dir, labels_train_dir)
    copy_files(val_pairs, imgs_val_dir, labels_val_dir)
    print(f"Dataset split into train ({len(train_pairs)}) and val ({len(val_pairs)})")
if __name__ == "__main__":
    # Example run: source image folder, source label folder, and the root
    # folder that receives the train/val split.
    source_images = r"D:\Myself文档\Dataset\官方Dataset\spain\original\original_select_15\images"
    source_labels = r"D:\Myself文档\Dataset\官方Dataset\spain\original\original_select_15\labels"
    split_root = r"D:\Myself文档\Dataset\官方Dataset\spain\original_select15"
    split_dataset(
        img_dir=source_images,
        labels_dir=source_labels,
        output_dir=split_root,
    )
img_dir:输入图片集所在的目录。
labels_dir:输入标签集所在的目录。
output_dir:输出目录,脚本会在其中自动生成 imgs 和 labels 两个子文件夹(各自包含 train、val 子目录)。注意:需要把生成的 imgs 目录重命名为 images,YOLO 才能直接使用该数据集进行训练。
划分训练集、验证集、测试集:
import os
import shutil
import random
def split_dataset(img_dir, labels_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1, image_ext='.jpg',
                  label_ext='.txt', seed=None):
    """Copy an image/label dataset into train, val and test sub-folders.

    Creates ``<output_dir>/imgs/{train,val,test}`` and
    ``<output_dir>/labels/{train,val,test}`` and copies each image together with
    its label file (same base name, with ``label_ext``) into the matching pair of
    folders. Source files are left in place.

    Images that have no label file are reported and excluded *before* the split,
    so the ratios and the printed counts describe the files that are actually copied.

    Args:
        img_dir: Directory containing the source images.
        labels_dir: Directory containing the source label files.
        output_dir: Root directory under which the split is written.
        train_ratio: Fraction of usable images placed in the train subset.
        val_ratio: Fraction of usable images placed in the val subset.
        test_ratio: Fraction for the test subset; it receives whatever remains
            after the train and val subsets have been cut.
        image_ext: Image file extension, or a tuple of extensions. Matching is
            case-insensitive, so ``'.jpg'`` also selects ``.JPG`` files.
        label_ext: Extension of the label files.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` state is used.

    Raises:
        AssertionError: If the three ratios do not sum to 1 (within 1e-4).
    """
    # Raised explicitly (rather than via ``assert``) so the check survives ``python -O``;
    # the exception type is kept as AssertionError for existing callers.
    if not abs(train_ratio + val_ratio + test_ratio - 1) < 0.0001:
        raise AssertionError("Ratios must sum up to 1")

    # Create the output directories.
    subsets = ('train', 'val', 'test')
    img_dirs = {name: os.path.join(output_dir, 'imgs', name) for name in subsets}
    label_dirs = {name: os.path.join(output_dir, 'labels', name) for name in subsets}
    for d in list(img_dirs.values()) + list(label_dirs.values()):
        os.makedirs(d, exist_ok=True)

    # Normalise the extension(s) so matching ignores case.
    extensions = (image_ext,) if isinstance(image_ext, str) else tuple(image_ext)
    extensions = tuple(ext.lower() for ext in extensions)

    # Collect (image, label) pairs in sorted order; drop unlabeled images up front
    # so they cannot distort the split sizes.
    pairs = []
    for img_name in sorted(os.listdir(img_dir)):
        if not img_name.lower().endswith(extensions):
            continue
        label_name = os.path.splitext(img_name)[0] + label_ext
        if os.path.exists(os.path.join(labels_dir, label_name)):
            pairs.append((img_name, label_name))
        else:
            print(f"Label file for {img_name} not found. Skipping this file.")

    # Shuffle; a private generator keeps a seeded run from touching global state.
    if seed is None:
        random.shuffle(pairs)
    else:
        random.Random(seed).shuffle(pairs)

    # The small epsilon guards against float products that land just below an
    # integer (e.g. 90 * 0.7 == 62.99999999999999) being truncated one too low.
    total = len(pairs)
    train_end = int(total * train_ratio + 1e-9)
    val_end = train_end + int(total * val_ratio + 1e-9)
    split = {
        'train': pairs[:train_end],
        'val': pairs[train_end:val_end],
        'test': pairs[val_end:],
    }

    # Copy each subset's images and labels into its own pair of folders.
    for name in subsets:
        for img_name, label_name in split[name]:
            shutil.copy(os.path.join(img_dir, img_name), img_dirs[name])
            shutil.copy(os.path.join(labels_dir, label_name), label_dirs[name])

    print(f"Dataset split into train ({len(split['train'])}), val ({len(split['val'])}), and test ({len(split['test'])})")
if __name__ == "__main__":
    # Example run: source image folder, source label folder, and the root
    # folder that receives the train/val/test split.
    source_images = r"D:\Myself文档\Dataset\官方Dataset\spain\original\original_select_15\img"
    source_labels = r"D:\Myself文档\Dataset\官方Dataset\spain\original\original_select_15\labels"
    split_root = r"D:\Myself文档\Dataset\官方Dataset\spain\original_select15"
    split_dataset(
        img_dir=source_images,
        labels_dir=source_labels,
        output_dir=split_root,
    )