Python 文件：split.py（按比例将分类数据集拆分为 train / validation / test）
import argparse
def parse_args(argv=None):
    """Parse the command-line options of the dataset splitter.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy when calling from code or tests).

    Returns:
        argparse.Namespace with ``source_dir``, ``target_dir``,
        ``validation_nums`` and ``test_nums`` attributes.

    Exits through ``parser.error`` (SystemExit, status 2) when
    ``--source_dir`` is missing or the fractions are out of range.
    """
    parser = argparse.ArgumentParser(
        description="Copy the class sub-folders of SOURCE_DIR into "
                    "train/validation/test splits.")
    parser.add_argument(
        '--source_dir',
        type=str,
        # Without this, a missing value becomes None and os.listdir(None)
        # silently lists the current working directory.
        required=True,
        help='Directory that contains one sub-directory per class.')
    parser.add_argument(
        '--target_dir',
        type=str,
        default=None,
        help='Output directory for the splits. Defaults to a folder named '
             '"datasets" next to SOURCE_DIR.')
    parser.add_argument(
        '--validation_nums',
        type=float,
        default=0.2,
        help='Fraction (0-1) of each class copied to validation '
             '(default: 0.2).')
    parser.add_argument(
        '--test_nums',
        type=float,
        default=0.1,
        help='Fraction (0-1) of each class copied to test (default: 0.1).')
    args = parser.parse_args(argv)

    if args.validation_nums < 0 or args.test_nums < 0:
        parser.error("--validation_nums and --test_nums must be non-negative")
    # Small tolerance so e.g. 0.7 + 0.3 is not rejected because of float noise.
    if args.validation_nums + args.test_nums > 1 + 1e-9:
        parser.error("--validation_nums + --test_nums must not exceed 1")
    return args
# Data copy-and-split function and its private helpers.
def _list_visible(path, keep):
    """Return the sorted entry names of *path* that are not hidden (no
    leading '.') and whose full path satisfies ``keep`` (e.g. os.path.isdir)."""
    import os
    return [name for name in sorted(os.listdir(path))
            if not name.startswith('.') and keep(os.path.join(path, name))]


def _held_out_count(total, fraction):
    """Return ceil(total * fraction), ignoring binary floating-point noise.

    Held-out sets are rounded *up*. Rounding the product to 9 decimals
    first stops noise such as 10 * (0.2 + 0.1) == 3.0000000000000004 from
    turning a held-out count of 3 into 4.
    """
    import math
    return math.ceil(round(total * fraction, 9))


def SplitTrainData(source_dir, target_dir=None, validation_nums=0.2, test_nums=0.1):
    """Copy every class folder of *source_dir* into train/validation/test splits.

    ``source_dir`` must contain one sub-directory per class; hidden entries
    (names starting with '.') and non-directory entries are ignored. Inside a
    class folder the non-hidden regular files are sorted by name and split in
    that order: the leading files go to ``train``, the next part to
    ``validation`` and the last part to ``test``. Files are copied, not moved.
    The validation and test counts are rounded up.

    Args:
        source_dir: Directory whose sub-directories are the classes.
        target_dir: Output directory. If None, a folder named ``datasets`` is
            created next to ``source_dir`` (``source_dir/../datasets``).
        validation_nums: Fraction (0-1) of each class copied to validation.
        test_nums: Fraction (0-1) of each class copied to test.

    Returns:
        dict mapping class name -> (n_train, n_validation, n_test).

    Raises:
        ValueError: if a fraction is negative or the two sum to more than 1.
        OSError: if a directory cannot be created or a file cannot be copied.
    """
    import os
    import shutil

    # Small tolerance so e.g. 0.7 + 0.3 is not rejected because of float noise.
    if validation_nums < 0 or test_nums < 0 or validation_nums + test_nums > 1 + 1e-9:
        raise ValueError(
            "validation_nums and test_nums must be non-negative and sum to at most 1")

    if target_dir is None:
        target_dir = os.path.join(source_dir, os.pardir, 'datasets')

    train_class = _list_visible(source_dir, os.path.isdir)
    splits = ('train', 'validation', 'test')

    # exist_ok=True makes re-running into an existing target directory safe;
    # genuine failures (e.g. permissions) propagate instead of being swallowed.
    for split in splits:
        os.makedirs(os.path.join(target_dir, split), exist_ok=True)
        for cls in train_class:
            os.makedirs(os.path.join(target_dir, split, cls), exist_ok=True)

    counts = {}
    for i in train_class:
        class_dir = os.path.join(source_dir, i)
        files_names = _list_visible(class_dir, os.path.isfile)
        total = len(files_names)
        held_out = min(total, _held_out_count(total, validation_nums + test_nums))
        # Test is a subset of the held-out tail, so it can never exceed it.
        n_test = min(held_out, _held_out_count(total, test_nums))
        validation_starts = total - held_out
        test_starts = total - n_test
        parts = (files_names[:validation_starts],
                 files_names[validation_starts:test_starts],
                 files_names[test_starts:])
        for split, names in zip(splits, parts):
            dest = os.path.join(target_dir, split, i)
            for file_name in names:
                shutil.copy(os.path.join(class_dir, file_name), dest)
        print(i+" has been copy success")
        print(i+": train " + str(validation_starts)+" validation " +str(test_starts-validation_starts) + " test "+str(len(files_names)-test_starts))
        counts[i] = (validation_starts, test_starts - validation_starts, total - test_starts)
    return counts
if __name__ == '__main__':
    # Parse sys.argv only when run as a script, so importing this module
    # (e.g. from a notebook or a test) has no side effects.
    args = parse_args()
    SplitTrainData(args.source_dir, args.target_dir,
                   args.validation_nums, args.test_nums)
命令行带参数调用示例（在 Notebook 单元中用 ! 前缀执行 shell 命令）：
!python /home/aistudio/work/split.py --source_dir=/home/aistudio/catVSdog/train