运行下述脚本即可将数据集随机划分为 train 和 val 两部分。
'''
实现功能:
对输入数据集进行随机划分
实现方式:
@param: data_path 数据集路径
@param: output_path 数据集输出路径
@param: split_ratio 数据集划分比例
Author:
Lily-20231123
'''
import os
import random
import shutil
import argparse
def _copy_split(img_files, img_folder, lab_folder, img_dst, lab_dst, tag):
    """Copy each image, plus its same-stem ``.txt`` label if present, into one split.

    @param img_files: image file names (relative to ``img_folder``) for this split.
    @param img_folder: source directory of the images.
    @param lab_folder: source directory of the label files.
    @param img_dst: destination directory for the images.
    @param lab_dst: destination directory for the labels.
    @param tag: split name used as the prefix of the progress messages.
    """
    for img_file in img_files:
        img_out = os.path.join(img_dst, img_file)
        shutil.copy(os.path.join(img_folder, img_file), img_out)
        print(tag + '_img: ', img_out)
        # splitext handles extensions of any length (".jpeg", ".tiff", ...);
        # slicing off a fixed 3 characters would not.
        lab_name = os.path.splitext(img_file)[0] + '.txt'
        lab_src = os.path.join(lab_folder, lab_name)
        if os.path.isfile(lab_src):
            lab_out = os.path.join(lab_dst, lab_name)
            shutil.copy(lab_src, lab_out)
            print(tag + '_lab: ', lab_out)


def split_dataset(input_folder, output_folder, split_ratio=0.8, seed=None):
    """Randomly split an images/labels dataset into train and val subsets.

    ``input_folder`` must contain an ``images`` directory and a ``labels``
    directory. Every regular file in ``images`` is copied to
    ``output_folder/images/{train,val}``. The label with the same base name
    and a ``.txt`` extension (e.g. ``0001.jpeg`` -> ``0001.txt``) is copied to
    ``output_folder/labels/{train,val}`` when it exists. Images without a
    label are still copied.

    @param input_folder: dataset root containing ``images`` and ``labels``.
    @param output_folder: destination root; the split subdirectories are
        created if missing.
    @param split_ratio: fraction of images assigned to train, in [0, 1].
        The train count is ``int(n_images * split_ratio)`` (rounded down).
    @param seed: optional seed for a reproducible split. ``None`` (default)
        draws from the global ``random`` state.
    @raise ValueError: if ``split_ratio`` is outside [0, 1].
    """
    # Fail fast, before any directory is created or file copied.
    if not 0 <= split_ratio <= 1:
        raise ValueError(
            'split_ratio must be within [0, 1], got {!r}'.format(split_ratio))

    train_img_folder = os.path.join(output_folder, 'images', 'train')
    test_img_folder = os.path.join(output_folder, 'images', 'val')
    train_lab_folder = os.path.join(output_folder, 'labels', 'train')
    test_lab_folder = os.path.join(output_folder, 'labels', 'val')
    for folder in (train_img_folder, test_img_folder,
                   train_lab_folder, test_lab_folder):
        os.makedirs(folder, exist_ok=True)

    img_folder = os.path.join(input_folder, 'images')
    lab_folder = os.path.join(input_folder, 'labels')

    # Skip subdirectories (shutil.copy cannot copy them). Sort because
    # os.listdir order is arbitrary, so a fixed seed could not otherwise
    # reproduce the same split.
    img_files = sorted(
        name for name in os.listdir(img_folder)
        if os.path.isfile(os.path.join(img_folder, name))
    )

    train_size = int(len(img_files) * split_ratio)
    rng = random if seed is None else random.Random(seed)
    train_img_files = rng.sample(img_files, train_size)
    # Set membership keeps the val partition O(n) instead of O(n^2).
    train_set = set(train_img_files)
    test_img_files = [name for name in img_files if name not in train_set]

    _copy_split(train_img_files, img_folder, lab_folder,
                train_img_folder, train_lab_folder, 'train')
    _copy_split(test_img_files, img_folder, lab_folder,
                test_img_folder, test_lab_folder, 'val')

    print("train nums: ", train_size)
    print("val nums: ", len(test_img_files))
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run split_dataset().
    parser = argparse.ArgumentParser(description='split datasets.')
    # NOTE(review): split_dataset() appends 'images' and 'labels' to this
    # path, so this default resolves to ...\smoke\images\images -- confirm
    # that matches the intended on-disk layout.
    parser.add_argument('--data_path', type=str, default=r'D:\domain_apdation\smoke\images',
                        help="数据集输入路径")
    parser.add_argument('--output_path', type=str, default=r'D:\domain_apdation\smoke\images_split',
                        help="数据集输出路径")
    parser.add_argument('--split_ratio', type=float, default=0.8,
                        help='数据集划分比例')
    args = parser.parse_args()
    # Report a bad ratio as a usage error instead of a traceback from
    # deep inside the split.
    if not 0 <= args.split_ratio <= 1:
        parser.error('--split_ratio must be within [0, 1]')
    # exist_ok replaces the racy exists()-then-makedirs() pair.
    os.makedirs(args.output_path, exist_ok=True)
    split_dataset(args.data_path, args.output_path, args.split_ratio)