划分数据集


```python
import os
import shutil
import random


def split_dataset(origin_path, new_dataset_path, split_ratio):
    '''

    :param origin_path:存放所有数据的路径
    :param new_dataset_path:新数据集的路径
    :param split_ratio: 存储训练集/测试集/验证集的划分比例的字典
    :return:
    '''
    # 用了os.walk函数而非os.listdir,是因为后者只能返回一级目录下的文件,而前者则可以将子目录下所有文件返回。
    for root, dirs, files in os.walk(origin_path, topdown=True):
        print(root, dirs, files)
        # 打乱数据
        random.shuffle(files)
        # 原数据的总长度
        origin_len = len(files)
        # 训练集、测试集、验证集的数据个数
        len_train = origin_len * split_ratio['train']
        len_test = origin_len * split_ratio['test']
        len_eval = origin_len - len_train - len_test

        if not os.path.exists(new_dataset_path):
            os.mkdir(new_dataset_path)
        # 存放子数据集的路径
        train_path = os.path.join(new_dataset_path, 'train')
        test_path = os.path.join(new_dataset_path, 'test')
        eval_path = os.path.join(new_dataset_path, 'eval')

        for sub_dataset_path in (train_path, test_path, eval_path):
            if not os.path.exists(sub_dataset_path):
                os.mkdir(sub_dataset_path)

        # 遍历每个文件
        for idx, file in enumerate(files):
            filename = os.path.join(root, file)
            if idx < len_train:
                shutil.copyfile(filename, os.path.join(train_path, file))
            elif idx - len_train < len_test:
                shutil.copyfile(filename, os.path.join(test_path, file))
            else:
                shutil.copyfile(filename, os.path.join(eval_path, file))


def set_split_radio(train_ratio, test_ratio, eval_ratio):
    split_ratio = dict()
    train_ratio, test_ratio = train_ratio / (train_ratio + eval_ratio + test_ratio), test_ratio / (
            train_ratio + eval_ratio + test_ratio)
    eval_ratio = 1 - train_ratio - test_ratio
    split_ratio['train'] = train_ratio
    split_ratio['test'] = test_ratio
    split_ratio['eval'] = eval_ratio
    return split_ratio


# 原数据存放的路径
origin_path = "D:/Daily Code/Python Code/New_Folder/cycled_csv_by_day"
# 划分后数据集的路径
new_dataset_path = os.path.join(os.getcwd(), 'split_dataset')
# 按训练集:测试集:验证集=7:2:1的比例划分数据集
split_ratio = set_split_radio(7, 2, 1)
# 划分数据集
split_dataset(origin_path, new_dataset_path, split_ratio)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值