# Read every annotation line up front and close the annotation file right away.
# Reading before opening the outputs means a missing/unreadable annotation file
# raises before any of the three output files are truncated by mode "w".
# my_dict maps 0-based line number -> raw line (trailing newline included).
with open(annotation_path, 'r') as anno:
    my_dict = dict(enumerate(anno))
# Number of annotation lines read.
cnt = len(my_dict)
totalnum = cnt
# Output files for the three splits; they stay open until the writing step.
# Training set
train_file = open(train_path, "w")
# Validation set
val_file = open(val_path, "w")
# Test set
test_file = open(test_path, "w")
import random

# Split sizes for train:val:test = 7:2:1. int() truncates, so any rounding
# remainder is absorbed by the test split.
train_num = int(totalnum * 0.7)
val_num = int(totalnum * 0.2)
test_num = totalnum - train_num - val_num
# Draw the held-out (test + val) indices in one sample without replacement:
# this keeps the two sets disjoint and restricts picks to the valid line
# indices 0..totalnum-1. (random.randint's upper bound is inclusive, so
# drawing with randint(0, totalnum) could pick the nonexistent index totalnum.)
_held_out = random.sample(range(totalnum), test_num + val_num)
test_set = set(_held_out[:test_num])
val_set = set(_held_out[test_num:])
# Every index not held out belongs to the training split.
train_set = set(range(totalnum)) - test_set - val_set
# Write each annotation line, in original order, to the file of the split its
# index was assigned to; anything in neither train_set nor val_set goes to test.
try:
    for i in range(cnt):
        strs = my_dict[i]
        if i in train_set:
            train_file.write(strs)
        elif i in val_set:
            val_file.write(strs)
        else:
            test_file.write(strs)
finally:
    # Call close() explicitly so buffered output is flushed; a bare
    # `f.close` attribute access does not close the file.
    train_file.close()
    val_file.close()
    test_file.close()
# 将数据集按 7:2:1 随机划分（训练集 : 验证集 : 测试集）
# Randomly split the dataset 7:2:1 (train : val : test).
# 最新推荐文章于 2024-05-23 01:12:57 发布