深度学习：分类数据集的划分（Python 代码）
说明
这段代码用于将分类数据集（每个类别对应一个子文件夹）按比例随机划分为训练集（train）和验证集（val）。
"""
This script splits a classification dataset into training and validation sets.
The input folder contains one sub-folder per class label; each class is split
independently into a train part and a val part.
"""
import argparse #参数解析器,可以使用
import os
import random
import shutil
import tqdm
def load_args(argv=None):
    """Parse the command-line options for the dataset split.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as the original version did.

    Returns:
        argparse.Namespace with attributes:
            path (str): source root holding one sub-folder per class.
            radtio (float): fraction of each class copied to ``train``;
                accepted as ``--ratio`` or the legacy ``--radtio`` flag.
            dataset_type (str): layout tag (e.g. train/clas1, train/clas2).
            save (str): output root under which ``train/`` and ``val/`` go.
            seed (int or None): optional RNG seed for a reproducible split.
    """
    # Container for all command-line options.
    parser = argparse.ArgumentParser(
        description='Split a class-per-folder dataset into train/val sets.')
    parser.add_argument('--path', type=str,
                        default=r'C:\Users\Administrator\Desktop\train',
                        help='source root containing one sub-folder per class')
    # '--radtio' is the original (misspelled) flag. It stays accepted, and the
    # attribute keeps the name ``radtio``, so existing callers keep working.
    parser.add_argument('--ratio', '--radtio', dest='radtio', type=float,
                        default=0.8,
                        help='fraction of each class used for training (0-1)')
    parser.add_argument('--dataset_type', type=str, default='train_val',
                        help='dataset layout tag')
    parser.add_argument('--save', type=str,
                        default=r'C:\Users\Administrator\Desktop\data',
                        help='output root for the train/ and val/ folders')
    # Optional: leaving it unset keeps the previous, unseeded behaviour.
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed for a reproducible split')
    return parser.parse_args(argv)
def split(args):
    """Randomly split every class folder under ``args.path`` into train/val copies.

    Expected input layout::

        args.path/<class_name>/<files>

    For each class sub-folder, ``int(n_files * args.radtio)`` files are picked
    at random and copied to ``args.save/train/<class_name>``. The remaining
    files are copied to ``args.save/val/<class_name>``. Source files are
    copied, never moved. Top-level entries that are not directories, and
    sub-directories inside a class folder, are ignored.

    Args:
        args: namespace with ``path`` (source root), ``radtio`` (fraction of
            each class that goes to train, in [0, 1]) and ``save`` (output
            root; created, with parents, if missing). An optional ``seed``
            attribute that is not None makes the split reproducible.

    Raises:
        ValueError: if ``args.radtio`` is outside [0, 1].
    """
    ratio = args.radtio
    # Fail before touching the filesystem instead of deep inside random.sample.
    if not 0.0 <= ratio <= 1.0:
        raise ValueError('split ratio must be in [0, 1], got {!r}'.format(ratio))

    # A seed gives a private RNG; without one, keep using the global random
    # module so callers who seed it themselves still get the same behaviour.
    seed = getattr(args, 'seed', None)
    rng = random if seed is None else random.Random(seed)

    # Sorted so that a given seed yields the same split regardless of the
    # arbitrary order os.listdir returns entries in.
    class_names = sorted(
        entry for entry in os.listdir(args.path)
        if os.path.isdir(os.path.join(args.path, entry))
    )
    for class_folder in class_names:
        src_dir = os.path.join(args.path, class_folder)
        # List once and keep only regular files. The sampled names and the
        # copied names then come from the same snapshot, and nested
        # directories are never handed to shutil.copy.
        files = sorted(
            name for name in os.listdir(src_dir)
            if os.path.isfile(os.path.join(src_dir, name))
        )
        train_count = int(len(files) * ratio)
        # Set gives O(1) membership tests below (the old list check was O(n^2)).
        train_names = set(rng.sample(files, train_count))

        train_dir = os.path.join(args.save, 'train', class_folder)
        val_dir = os.path.join(args.save, 'val', class_folder)
        # makedirs also creates args.save (and its parents) when missing.
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        for name in files:
            dst_dir = train_dir if name in train_names else val_dir
            shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))
        print('{}:已完成划分'.format(class_folder))
if __name__ == '__main__':
    # Script entry point: parse the command line, then copy files into train/val.
    split(load_args())