2021-06-30

python深度学习数据集划分详解

import os
import random
import sys
import shutil
#abspath表示此时本文件所在的位置,dirname表示上级目录所在是位置
BASE_DIR=os.path.dirname(os.path.abspath(__file__))
print(BASE_DIR)
#如果我们创建的新路径不存在,则创建该路径
def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

if __name__=='__main__':
    #路径拼接,创建一系列文件夹
    dataset_dir = os.path.abspath(os.path.join(BASE_DIR,"data", "RMB_data"))
    print(dataset_dir)
    split_dir = os.path.abspath(os.path.join(BASE_DIR, "data", "rmb_split"))
    print(split_dir)
    train_dir = os.path.join(split_dir, "train")
    print(train_dir)
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")
    if not os.path.exists(dataset_dir):
        raise Exception("\n{}不存在,请下载 02-01-数据-RMB_data.rar 放到\n{} 下,并解压即可".format(
            dataset_dir, os.path.dirname(dataset_dir)))
    #训练集,验证集,测试集划分的比例
    train_pct=0.8
    valid_pct=0.1
    test_pct=0.1
    #root指的是当前所在的文件夹路径,dirs是当前文件夹路径下的文件夹列表,files是当前文件夹路径下的文件列表。
    for root,dirs,file in os.walk(dataset_dir):
        #遍历每一个文件夹列表,获取文件夹下的每一个文件
        for sub_dir in dirs:
            imgs=os.listdir(os.path.join(root,sub_dir))
            #将.jpg文件过滤出来
            imgs=list(filter(lambda x: x.endswith('jpg'),imgs))
            #打乱顺序
            random.shuffle(imgs)
            #图片数量
            img_count=len(imgs)
            #训练集数量
            train_point=int(img_count*train_pct)
            #验证集数量
            valid_point=int(img_count*(train_pct+valid_pct))
            if img_count==0:
                print("{}目录下,无图片,请检查".format(os.path.join(root, sub_dir)))
                sys.exit(0)
            #将图片存入对应的文件夹中
            for i in range(img_count):
                if i<train_point:
                    out_dir=os.path.join(train_dir,sub_dir)

                elif i<valid_point:
                    out_dir=os.path.join(valid_dir,sub_dir)

                else:
                    out_dir = os.path.join(test_dir, sub_dir)
                #创建对应的路径
                makedir(out_dir)
                target_path=os.path.join(out_dir,imgs[i])
                src_path=os.path.join(dataset_dir,sub_dir,imgs[i])
                #将前者复制到后者中
                shutil.copy(src_path,target_path)
            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point - train_point,
                                                                 img_count - valid_point))
            print("已在 {} 创建划分好的数据\n".format(out_dir))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值