# -*- coding: utf-8 -*-
"""
将数据集划分为训练集,测试集
"""
import os
import random
import shutil
import math
# 创建保存图像的文件夹
def makedir(new_dir):
    """Create the directory ``new_dir`` (and any missing parents).

    Does nothing if the directory already exists.

    Args:
        new_dir: Path of the directory to create.

    Raises:
        FileExistsError: If ``new_dir`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of
    # ``if not os.path.exists(...): os.makedirs(...)``, which can raise if the
    # directory is created by someone else between the two calls.
    os.makedirs(new_dir, exist_ok=True)
random.seed(2021) # Fixed seed for the global random generator (used by random.shuffle in split_dataset)
def split_dataset(imgs_dir, train_dir, test_dir, rate):
    """Copy the files in ``imgs_dir`` into a training and a test directory.

    The file names are sorted, shuffled with the global ``random`` generator,
    and then the first ``ceil(count * rate)`` files are copied to
    ``train_dir`` and the remaining ones to ``test_dir``. Source files are
    copied, not moved, so ``imgs_dir`` is left untouched. The size of the
    training set is printed to stdout.

    Args:
        imgs_dir: Directory containing the files to split.
        train_dir: Destination directory for training files; created if
            missing.
        test_dir: Destination directory for test files; created if missing.
        rate: Fraction of the files assigned to the training set, in [0, 1].

    Raises:
        ValueError: If ``rate`` is not within [0, 1].
    """
    # Reject bad ratios up front: rate > 1 would index past the end of the
    # file list, and rate < 0 would yield negative indices.
    if not 0 <= rate <= 1:
        raise ValueError(f"rate must be between 0 and 1, got {rate!r}")

    makedir(train_dir)
    makedir(test_dir)

    # Keep regular files only: shutil.copy cannot copy a sub-directory.
    # Sort before shuffling because os.listdir returns entries in arbitrary
    # order, which would make the shuffle differ between machines even with
    # the same random state.
    imgs = sorted(
        name
        for name in os.listdir(imgs_dir)
        if os.path.isfile(os.path.join(imgs_dir, name))
    )
    random.shuffle(imgs)

    training_dataset_numbers = math.ceil(len(imgs) * rate)  # round up
    print(training_dataset_numbers)

    splits = (
        (train_dir, imgs[:training_dataset_numbers]),
        (test_dir, imgs[training_dataset_numbers:]),
    )
    for dest_dir, names in splits:
        for name in names:
            src_path = os.path.join(imgs_dir, name)
            target_path = os.path.join(dest_dir, name)
            shutil.copy(src_path, target_path)  # copy; the source stays in place
if __name__ == "__main__":
    # Split the A4 label files: 80% go to the training folder, the rest to
    # the test folder.
    split_dataset(
        imgs_dir='A4/label',
        train_dir='A4/A4_train_label',
        test_dir='A4/A4_test_label',
        rate=0.8,
    )