利用 python 在本地数据集创建训练集和测试集

根据自己的数据集,自动划分训练集、测试集

举个栗子:
	已经分好的文件:

在这里插入图片描述

	origin 文件夹有三类数据:good,bad,m,每类文件夹包含不同数量的图片,如下:

在这里插入图片描述

	需要生成数据集的文件:

在这里插入图片描述

	结果:根据设置好的比例,划分数据集和测试集

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

	完整代码:
import random
import os
import shutil
import glob
class get_data_sets():
    '''
    input_address:输入地址
    output_adddress:输出的地址
    train_ratio:训练集站比,(0,1)
    '''
    def __init__(self,input_address,output_adddress, train_ratio):
        self.__input_address = input_address
        self.__output_address = output_address
        self.__train_ratio = train_ratio
    def run(self):
        #获取数据种类
        class_address_list = glob.glob(self.__input_address + '\*')
        class_name_list = [ class_address.split('\\')[-1] for class_address in class_address_list ]
        #print
        print('数据分类为 {} \n训练集占比 {}'.format((class_name_list), self.__train_ratio))
        
        #新建训练、测试文件夹
        train_address = self.__output_address + '/train'
        test_address = self.__output_address + '/test'
        
        os.mkdir(train_address)
        os.mkdir(test_address)
        
        #在训练、测试文件夹 新建 类型 文件
        for class_name in class_name_list:
            os.mkdir(train_address + '/{}'.format(class_name))
            os.mkdir(test_address + '/{}'.format(class_name))
        
        #获取训练、测试数据
        class_num = [ len(os.listdir(all_class_address))  for all_class_address in class_address_list ] # 获取每类数据长度
        
        random.seed(2) #设置种子,保证每次分类一致
        train_address_list = [train_address + '/{}'.format(class_name)  for class_name in class_name_list]
        test_address_list = [test_address + '/{}'.format(class_name)  for class_name in class_name_list]

        #复制文件
        for i,num in enumerate(class_num):
            all_index = set(range(num))
            train_index = random.sample(all_index,int(self.__train_ratio*num))
            test_index = all_index - set(train_index)
            
            data_list = glob.glob(class_address_list[i] + '\*')
               
            for _ in train_index:
                shutil.copy(data_list[_], train_address_list[i])

            for _ in test_index:
                shutil.copy(data_list[_], test_address_list[i] )
        
        print('创建完成')
        


	运行代码:
input_address = r"D:\A_test\csdn_test\data_sets\origin"
output_address = "D:\A_test\csdn_test\data_sets\data"
a = get_data_sets(input_address,output_address,0.6)
a.run()

在这里插入图片描述

  • 3
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值