```python
import os
import shutil
import random
def split_dataset(origin_path, new_dataset_path, split_ratio):
    """
    Randomly copy the files under ``origin_path`` into train/test/eval folders.

    ``os.walk`` is used instead of ``os.listdir`` so that files in
    sub-directories are picked up as well. Every directory visited by the walk
    is handled on its own: its file list is shuffled and then divided according
    to ``split_ratio``. Files are copied (the originals stay in place) into
    ``new_dataset_path/train``, ``new_dataset_path/test`` and
    ``new_dataset_path/eval``.

    NOTE: copies are stored under their bare file name, so files that share a
    name but live in different sub-directories overwrite each other in the
    output folders.

    :param origin_path: directory that holds all of the original data
    :param new_dataset_path: directory that receives the split dataset; it is
        created, including missing parent directories, when it does not exist
    :param split_ratio: dict holding the split fractions; the 'train' and
        'test' entries are read, every remaining file goes to the eval folder
    :return: None
    """
    # Output folders for the three sub-datasets.
    train_path = os.path.join(new_dataset_path, 'train')
    test_path = os.path.join(new_dataset_path, 'test')
    eval_path = os.path.join(new_dataset_path, 'eval')
    for sub_dataset_path in (train_path, test_path, eval_path):
        # makedirs also creates new_dataset_path and any missing parents.
        os.makedirs(sub_dataset_path, exist_ok=True)

    # Normalised output location, used to keep the walk out of it.
    output_root = os.path.normcase(os.path.abspath(new_dataset_path))

    for root, dirs, files in os.walk(origin_path, topdown=True):
        print(root, dirs, files)

        # With topdown=True, editing `dirs` in place prunes the walk. Skipping
        # the output directory prevents already-copied files from being copied
        # onto themselves when the output lives inside origin_path.
        dirs[:] = [
            d for d in dirs
            if os.path.normcase(os.path.abspath(os.path.join(root, d))) != output_root
        ]

        # Shuffle so that the assignment to the sub-datasets is random.
        random.shuffle(files)

        # Fractional thresholds: a file belongs to the training set while its
        # index is below len_train, to the test set for the next len_test
        # positions, and to the eval set otherwise.
        origin_len = len(files)
        len_train = origin_len * split_ratio['train']
        len_test = origin_len * split_ratio['test']

        for idx, name in enumerate(files):
            source = os.path.join(root, name)
            if idx < len_train:
                target_dir = train_path
            elif idx - len_train < len_test:
                target_dir = test_path
            else:
                target_dir = eval_path
            shutil.copyfile(source, os.path.join(target_dir, name))
def set_split_radio(train_ratio, test_ratio, eval_ratio):
    """
    Turn three relative weights into split fractions.

    :param train_ratio: relative weight of the training set
    :param test_ratio: relative weight of the test set
    :param eval_ratio: relative weight of the eval set
    :return: dict with the keys 'train', 'test' and 'eval'; the train and test
        fractions are the weights divided by the sum of all three weights, the
        eval fraction is whatever remains up to 1
    """
    weight_sum = train_ratio + eval_ratio + test_ratio
    train_share = train_ratio / weight_sum
    test_share = test_ratio / weight_sum
    # The eval share is derived from the other two, not from its own weight.
    return {
        'train': train_share,
        'test': test_share,
        'eval': 1 - train_share - test_share,
    }
# Directory that holds the original data
origin_path = "D:/Daily Code/Python Code/New_Folder/cycled_csv_by_day"
# Directory for the split dataset: a 'split_dataset' folder under the current working directory
new_dataset_path = os.path.join(os.getcwd(), 'split_dataset')
# Split the dataset with train : test : eval = 7 : 2 : 1
split_ratio = set_split_radio(7, 2, 1)
# Run the split
split_dataset(origin_path, new_dataset_path, split_ratio)