import os
import random
import shutil
from tqdm import tqdm
# 功能:将变化检测原始数据划分按比例为训练集、验证机和测试集
def split_dataset(
input_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15
):
"""
要求变化检测原始数据集子目录格式为:
--DATASET_DIR
--A
--B
--label
生成结果为:
--OUT_DATASET_DIR
--train
--A
--B
--label
--val
--A
--B
--label
--test
--A
--B
--label
要求对应前景影像、后景影像以及变化标签的名称一致!
"""
# 创建输出文件夹
train_dir = os.path.join(output_dir, "train")
val_dir = os.path.join(output_dir, "val")
test_dir = os.path.join(output_dir, "test")
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
# 获取文件列表
files = os.listdir(os.path.join(input_dir, "A"))
num_files = len(files)
num_train = int(num_files * train_ratio)
num_val = int(num_files * val_ratio)
num_test = num_files - num_train - num_val
# 随机打乱文件列表
random.shuffle(files)
# 划分数据集并复制文件
for i, file in enumerate(tqdm(files, desc="copying files")):
if i < num_train:
dest_dir = train_dir
elif i < num_train + num_val:
dest_dir = val_dir
else:
dest_dir = test_dir
# 复制A、B、label文件夹中的内容
for folder in ["A", "B", "label"]:
src_path = os.path.join(input_dir, folder, file)
dest_path = os.path.join(dest_dir, folder, file)
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
shutil.copy(src_path, dest_path)
print("Dataset split completed.")
if __name__ == "__main__":
# 输入和输出文件夹路径
input_dir = r"ccc"
output_dir = r"ccc_out"
# 划分数据集并保存到输出文件夹
split_dataset(input_dir, output_dir)