# 按比例划分 COCO 数据集,并更改标签类别号。

import os,shutil
import math
import random

#混合数据集划分
# path = ""
# destinationPath = ""
# os.chdir(path)
# datapath = os.getcwd()




class DataProcess:
    """Helpers for organising a detection dataset stored as image/label pairs.

    Every image is expected to have a ``.txt`` label file with the same stem;
    each non-empty label line starts with the class id.  The class can
      * sort a mixed dataset into one directory per class,
      * rewrite the class id of every label in a directory, and
      * split a dataset into ``train`` / ``valid`` / ``test`` subsets.
    """

    def __init__(self, sourcepath=None, train=0.6, valid=0.2, test=0.2):
        """Store the default source directory and the split ratios.

        Args:
            sourcepath: Default directory used by ``move_file``.  When omitted,
                the module-level ``datapath`` variable is used if it exists
                (legacy behaviour of the script entry point); otherwise it
                stays ``None`` instead of raising ``NameError``.
            train: Fraction of samples for the training split (informational;
                the training split is whatever is left after valid + test).
            valid: Fraction of samples for the validation split.
            test: Fraction of samples for the test split.
        """
        if sourcepath is None:
            sourcepath = globals().get('datapath')
        self.sourcepath = sourcepath
        self.train = train
        self.valid = valid
        self.test = test

    def move_file(self, key, desfile, sourcepath=None, exact=False):
        """Copy the files matching ``key`` into the directory ``desfile``.

        Despite the name, files are copied, not moved.

        Args:
            key: Text to look for in the file names.
            desfile: Destination directory; created when missing.
            sourcepath: Directory to scan; defaults to ``self.sourcepath``.
            exact: When ``False`` (default, legacy behaviour) a file matches
                if ``key`` is a substring of its name.  When ``True`` the
                file's stem (name without extension) must equal ``key``.

        Raises:
            ValueError: If no source directory is available.
        """
        if sourcepath is None:
            sourcepath = self.sourcepath
        if sourcepath is None:
            raise ValueError("sourcepath is not set")
        if not os.path.exists(desfile):
            print("desfile not exist,so create the dir")
        # The second positional argument of os.makedirs is the permission
        # mode, so passing 1 there created an unusable directory on POSIX.
        os.makedirs(desfile, exist_ok=True)
        for name in os.listdir(sourcepath):
            src = os.path.join(sourcepath, name)
            if not os.path.isfile(src):
                continue
            if exact:
                matched = os.path.splitext(name)[0] == key
            else:
                matched = key in name
            if matched:
                shutil.copy(src, os.path.join(desfile, name))

    def split_to_single_class(self, sourcepath, destpath):
        """Sort a mixed dataset into one sub-directory of ``destpath`` per class.

        The class of a sample is the first whitespace-separated token of the
        first non-blank line of its ``.txt`` label.  The label and every file
        in ``sourcepath`` sharing its stem are copied to ``destpath/<class>``.
        Label files without any content are skipped.
        """
        for file in os.listdir(sourcepath):
            if not file.endswith('.txt'):
                continue
            filepath = os.path.join(sourcepath, file)
            if not os.path.isfile(filepath):
                continue
            destinationfile = None
            with open(filepath, 'r') as f:
                for line in f:
                    fields = line.split()
                    if fields:
                        # Whole token, so multi-digit class ids survive.
                        destinationfile = fields[0]
                        break
            if destinationfile is None:
                continue
            print(destinationfile)
            desfile = os.path.join(destpath, destinationfile)
            key = os.path.splitext(file)[0]  # file name without the suffix
            # Exact stem match so "img1" does not also drag in "img10.*".
            self.move_file(key, desfile, sourcepath=sourcepath, exact=True)

    def change_cluss_num(self, path, newclassnum):
        """Replace the class id of every label line in every ``.txt`` in ``path``.

        Args:
            path: Directory containing the label files (rewritten in place).
            newclassnum: New class id; converted with ``str()``.

        Fields are re-joined with single spaces and all fields after the class
        id are preserved.  Blank lines are dropped.  The new content is built
        completely before the file is reopened for writing, so a malformed
        line can no longer leave a truncated label file behind.
        """
        newclassnum = str(newclassnum)
        for file in os.listdir(path):
            if not file.endswith('.txt'):
                continue
            txt_item = os.path.join(path, file)
            with open(txt_item, 'r') as f:
                lines = f.readlines()
            new_lines = []
            for line in lines:
                line_split = line.split()
                if not line_split:
                    continue
                line_split[0] = newclassnum
                new_lines.append(' '.join(line_split) + '\n')
            with open(txt_item, 'w') as f:
                f.writelines(new_lines)

    def make_dirs(self, dest_path, images_or_labels):
        """Ensure ``dest_path/<images_or_labels>/{train,valid,test}`` exist.

        The sub-directories are created even when the parent directory already
        exists, so a partially created tree is completed.
        """
        base = os.path.join(dest_path, str(images_or_labels))
        if not os.path.exists(base):
            print("desfile not exist,so create the dir")
        for split_name in ('train', 'valid', 'test'):
            os.makedirs(os.path.join(base, split_name), exist_ok=True)

    def copy_splited_data(self, filelist, sourcepath, destpath):
        """Copy each name in ``filelist`` from ``sourcepath`` into ``destpath``.

        ``destpath`` is created when missing; otherwise ``shutil.copy`` would
        treat it as a file name and every copy would overwrite the same file.
        """
        if not filelist:
            return
        os.makedirs(destpath, exist_ok=True)
        for file in filelist:
            shutil.copy(os.path.join(sourcepath, file), destpath)

    def _partition_labels(self, labelnames):
        """Randomly split ``labelnames`` into (train, valid, test) lists."""
        total = len(labelnames)
        holdout_ratio = self.valid + self.test
        holdout_count = min(total, max(0, int(holdout_ratio * total)))
        holdout = random.sample(labelnames, holdout_count)
        if holdout_ratio > 0:
            valid_count = int((self.valid / holdout_ratio) * len(holdout))
        else:
            valid_count = 0  # nothing held out; avoids dividing by zero
        valid = random.sample(holdout, valid_count)
        valid_set = set(valid)
        holdout_set = set(holdout)
        test = [name for name in holdout if name not in valid_set]
        train = [name for name in labelnames if name not in holdout_set]
        return train, valid, test

    def split_train_dataset(self, source_img_path, source_label_path, dest_path):
        """Split a dataset into train / valid / test subsets under ``dest_path``.

        Args:
            source_img_path: Directory containing the images (every regular
                file that does not end in ``.txt``).
            source_label_path: Directory containing the ``.txt`` labels.  May
                be the same directory as ``source_img_path``.
            dest_path: Output root; files are copied to
                ``images/<split>`` and ``labels/<split>`` below it.

        The split is drawn over the label files using ``self.valid`` and
        ``self.test``; the remainder is the training set.  An image follows
        the label whose stem equals its own stem.  Images without a label are
        not copied.  Names are sorted before sampling so that seeding
        ``random`` gives a reproducible split.
        """
        labelnames = [
            name for name in sorted(os.listdir(source_label_path))
            if name.endswith('.txt')
            and os.path.isfile(os.path.join(source_label_path, name))
        ]
        # stem -> image file names; replaces the quadratic substring search
        # that also matched e.g. label "img1" against image "img10.jpg".
        images_by_stem = {}
        for name in sorted(os.listdir(source_img_path)):
            if name.endswith('.txt'):
                continue
            if not os.path.isfile(os.path.join(source_img_path, name)):
                continue
            images_by_stem.setdefault(os.path.splitext(name)[0], []).append(name)

        if images_by_stem:
            self.make_dirs(dest_path, 'images')
        if not labelnames:
            return
        self.make_dirs(dest_path, 'labels')

        splits = zip(('train', 'valid', 'test'), self._partition_labels(labelnames))
        for split_name, labels in splits:
            images = []
            for label in labels:
                images.extend(images_by_stem.get(label[:-4], ()))
            self.copy_splited_data(
                images, source_img_path,
                os.path.join(dest_path, 'images', split_name))
            self.copy_splited_data(
                labels, source_label_path,
                os.path.join(dest_path, 'labels', split_name))






if __name__ == '__main__':
    # Script entry point: treat every sub-directory of `datapath` as one
    # class, record the class names in classes.txt, rewrite each class's
    # labels to its index and split each class into train/valid/test.
    #
    # Usage examples (comments rather than a string literal, because the
    # non-raw Windows path in the old string contained invalid escapes):
    #
    #   datapro = DataProcess()
    #   # change the class id of every label in a directory
    #   NewClassNum = '1'
    #   Path2 = r"D:\Learning_software\deeplearning\learnV2_Pytorch\chapter3"
    #   datapro.change_cluss_num(Path2, NewClassNum)
    #   # split one directory into train/valid/test
    #   dest_path = r"D:\masterPROJECT\laser_weeding\weed_dataset\testsplitAG\dest"
    #   source_img_path = r"D:\masterPROJECT\laser_weeding\weed_dataset\testsplitAG\1"
    #   source_label_path = r"D:\masterPROJECT\laser_weeding\weed_dataset\testsplitAG\1"
    #   datapro.split_train_dataset(source_img_path, source_label_path, dest_path)

    # `datapath` stays a module-level name: DataProcess() reads it.
    datapath = r"D:\masterPROJECT\laser_weeding\weed_dataset\dataclassfication"

    # os.listdir() returns entries in arbitrary order; sort so that a class
    # keeps the same id from run to run.
    classlist = sorted(
        name for name in os.listdir(datapath)
        if os.path.isdir(os.path.join(datapath, name))
    )

    # One class name per line; the line index is the class id.
    class_txt = os.path.join(datapath, 'classes.txt')
    with open(class_txt, 'w') as f:
        for line in classlist:
            f.write(line + '\n')

    dataprocess = DataProcess()
    # Rewrite the class ids according to the index in classlist, then split.
    destpath = r"D:\masterPROJECT\laser_weeding\weed_dataset\testsplitAG"
    for index, cls in enumerate(classlist):
        num = str(index)
        clspath = os.path.join(datapath, cls)
        print(num)
        dataprocess.change_cluss_num(clspath, num)
        dataprocess.split_train_dataset(clspath, clspath, destpath)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值