只需要将 图片文件所在目录、标注文件所在目录和保存目录的路径进行修改即可,划分比例可以自行修改。
在代码中找到如下位置修改数值即可
# 设置数据集划分比例
train_percent = 0.75
val_percent = 0.15
test_percent = 0.15
完整代码如下所示
import shutil
import random
import os
import argparse
# 检查文件夹是否存在,如果不存在则创建该文件夹
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
# 复制文件函数,如果源文件或目标文件不存在则打印错误信息
def copy_files(srcImage, dstImage, srcLabel, dstLabel):
# 检查文件是否存在
if os.path.exists(srcImage) and os.path.exists(srcLabel):
shutil.copyfile(srcImage, dstImage)
shutil.copyfile(srcLabel, dstLabel)
else:
missing_files = [srcImage, srcLabel]
missing_files = [file for file in missing_files if not os.path.exists(file)]
print(f"文件不存在: {', '.join(missing_files)}")
# 查找匹配的图片文件,支持多种图片格式
def find_matching_images(image_dir, name_without_ext):
# 支持的图片格式列表
supported_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tiff']
# 遍历所有支持的格式,找到匹配的图片文件
for img_format in supported_formats:
img_path = os.path.join(image_dir, f"{name_without_ext}{img_format}")
if os.path.exists(img_path):
return img_path
return None
# 主函数,用于划分数据集
def main(image_dir, txt_dir, save_dir):
# 创建保存数据集的目录结构
mkdir(save_dir)
images_dir = os.path.join(save_dir, 'images')
labels_dir = os.path.join(save_dir, 'labels')
img_train_path = os.path.join(images_dir, 'train')
img_test_path = os.path.join(images_dir, 'test')
img_val_path = os.path.join(images_dir, 'val')
label_train_path = os.path.join(labels_dir, 'train')
label_test_path = os.path.join(labels_dir, 'test')
label_val_path = os.path.join(labels_dir, 'val')
mkdir(images_dir)
mkdir(labels_dir)
mkdir(img_train_path)
mkdir(img_test_path)
mkdir(img_val_path)
mkdir(label_train_path)
mkdir(label_test_path)
mkdir(label_val_path)
# 设置数据集划分比例
train_percent = 0.75
val_percent = 0.15
test_percent = 0.15
# 获取所有标注文件
total_txt = os.listdir(txt_dir)
num_txt = len(total_txt)
list_all_txt = range(num_txt)
# 计算训练集、验证集、测试集的数目
num_train = int(num_txt * train_percent)
num_val = int(num_txt * val_percent)
num_test = num_txt - num_train - num_val
# 随机选取文件索引用于训练集、验证集、测试集的划分
train = random.sample(list_all_txt, num_train)
val_test = [i for i in list_all_txt if not i in train]
val = random.sample(val_test, num_val)
# 打印训练集、验证集、测试集的数目
print("训练集数目:{}, 验证集数目:{},测试集数目:{}".format(len(train), len(val), len(val_test) - len(val)))
# 遍历所有标注文件,根据划分结果复制到相应目录
for i in list_all_txt:
name = total_txt[i].replace('.txt', '') # 去除文件扩展名
# 查找与标注文件匹配的图片文件
srcImage = find_matching_images(image_dir, name)
srcLabel = os.path.join(txt_dir, name + '.txt')
# 如果找到了匹配的图片文件,则进行复制操作
if srcImage:
if i in train:
dst_train_Image = os.path.join(img_train_path, os.path.basename(srcImage))
dst_train_Label = os.path.join(label_train_path, name + '.txt')
copy_files(srcImage, dst_train_Image, srcLabel, dst_train_Label)
elif i in val:
dst_val_Image = os.path.join(img_val_path, os.path.basename(srcImage))
dst_val_Label = os.path.join(label_val_path, name + '.txt')
copy_files(srcImage, dst_val_Image, srcLabel, dst_val_Label)
else:
dst_test_Image = os.path.join(img_test_path, os.path.basename(srcImage))
dst_test_Label = os.path.join(label_test_path, name + '.txt')
copy_files(srcImage, dst_test_Image, srcLabel, dst_test_Label)
else:
# 如果没有找到匹配的图片文件,则打印错误信息
print(f"没有找到匹配的图片文件: {name}")
# 程序入口判断
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='将数据集划分为训练集、验证集和测试集')
parser.add_argument('--image-dir', type=str, default=r'D:/img', help='图片文件所在目录')
parser.add_argument('--txt-dir', type=str, default=r'D:/txt', help='标注文件所在目录')
parser.add_argument('--save-dir', default=r'D:/split', type=str, help='保存目录')
args = parser.parse_args()
image_dir = args.image_dir
txt_dir = args.txt_dir
save_dir = args.save_dir
# 调用主函数
main(image_dir, txt_dir, save_dir)