自定义数据集的训练测试txt文件生成

在自己数据集上根据指定比例,生成测试集和训练集,并写入txt文件

import os
import numpy as np

abs_str = ''            #绝对路径 + 文件名
dirname = ''            #源文件所在目录

import numpy as np


def get_file(dir):
    file_list = []
    label_list = []
    for (index,item) in enumerate(os.listdir(dir)):
        imagDir = os.path.join(os.path.abspath(dir),item)
        if(os.path.isdir(imagDir)):
            for image in os.listdir(imagDir):
                if os.path.isfile(os.path.join(imagDir,image)):
                    file_list.append(os.path.join(item,image))
                    label_list.append(index)
    return file_list,label_list

if __name__=="__main__":
    # file_list,label_list = get_file(dirname)
    # file_handle = open('train_test_split.txt',mode='w')
    # for i,j in enumerate(label_list):
    #     file_handle.write('{} '.format(i + 1))
    #     file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
    #     # file_handle.write(j) # images.txt
    #     file_handle.write('\n')
    # file_handle.close()
    # ------------------------split-test-spilt-V1	
    nums = np.ones(5378, dtype=int)
    test_size = int(0.8 * len(nums))
    nums[:test_size] = 0
    np.random.shuffle(nums)
    file_handle = open('train_test_split.txt', mode='w')
    for i,j in enumerate(nums):
        file_handle.write('{} '.format(i + 1))
        # file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
        file_handle.write('{} '.format(j)) # train_test_split.txt
        # file_handle.write(j) # images.txt
        file_handle.write('\n')
    file_handle.close()

发现一点问题,不能简单的根据一个随机数进行划分,存在一种可能是在某一个类中没有取到训练或者测试数据,有问题,因此还是需要进行遍历每一个文件夹,有了如下的更新:

import os
import numpy as np
import random
abs_str = ''            #绝对路径 + 文件名
dirname = ''            #源文件所在目录


def get_file(dir):
    file_list = []
    label_list = []
    for (index,item) in enumerate(os.listdir(dir)):
        imagDir = os.path.join(os.path.abspath(dir),item)
        if(os.path.isdir(imagDir)):
            for image in os.listdir(imagDir):
                if os.path.isfile(os.path.join(imagDir,image)):
                    file_list.append(os.path.join(item,image))
                    label_list.append(index)
    return file_list,label_list

def get_train_test(dir,split_rate):
    train_test_list = []
    for (index,item) in enumerate(os.listdir(dir)):
        imagDir = os.path.join(os.path.abspath(dir),item)
        if(os.path.isdir(imagDir)):
            print('imagDir', imagDir)
            images = os.listdir(imagDir)
            num = len(images)
            eval_index = random.sample(images, k=int(num * split_rate))
            for index, image in enumerate(images):
                if image in eval_index:
                    # 将分配至验证集中的文件复制到相应目录
                    train_test_list.append(0)
                else:
                    # 将分配至训练集中的文件复制到相应目录
                    train_test_list.append(1)
            print()
    return train_test_list
if __name__=="__main__":
    # file_list,label_list = get_file(dirname)
    # file_handle = open('train_test_split.txt',mode='w')
    # for i,j in enumerate(label_list):
    #     file_handle.write('{} '.format(i + 1))
    #     file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
    #     # file_handle.write(j) # images.txt
    #     file_handle.write('\n')
    # file_handle.close()
    # ------------------------split-test-spilt-V1
    # nums = np.zeros(5378, dtype=int)
    # test_size = int(0.8 * len(nums))
    # nums[:test_size] = 1
    # np.random.shuffle(nums)
    # file_handle = open('train_test_split.txt', mode='w')
    # for i,j in enumerate(nums):
    #     file_handle.write('{} '.format(i + 1))
    #     # file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
    #     file_handle.write('{} '.format(j)) # train_test_split.txt
    #     # file_handle.write(j) # images.txt
    #     file_handle.write('\n')
    # file_handle.close()
    # ------------------------split-test-spilt-V2
    nums = get_train_test(dirname,0.2)
    file_handle = open('train_test_split.txt', mode='w')
    for i,j in enumerate(nums):
        file_handle.write('{} '.format(i + 1))
        # file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
        file_handle.write('{} '.format(j)) # train_test_split.txt
        # file_handle.write(j) # images.txt
        file_handle.write('\n')
    file_handle.close()
    n = str(nums).count('1')
    m = str(nums).count('0')
    print('nums',nums)
    print('train',n)
    print('value',m)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值