将 JSON 格式的数据集划分为训练集和验证集
用法:可用于划分 COCO 格式的数据集,也可以扩展为额外划分测试集,或生成多组训练集+验证集。根据自身需求修改代码即可。
split_train_val.py
执行 python split_train_val.py --img_path <图片文件夹> --json_file <标注文件>.json --output <输出目录>
即可。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# 将一个文件夹下图片按比例分在两个文件夹下
import os
import random
from shutil import copy2
import json
import argparse
def _select_split(data, stems, file_names):
    """Return the ``(images, annotations)`` of ``data`` that belong to one split.

    An image record belongs to the split when its ``id`` equals one of
    ``stems`` or when the base name of its ``file_name`` is one of
    ``file_names``.  An annotation belongs to the split when its ``image_id``
    is one of ``stems`` or the ``id`` of a selected image.
    """
    stem_set = set(stems)
    name_set = set(file_names)
    images = [
        image for image in data['images']
        if image['id'] in stem_set
        or os.path.basename(str(image.get('file_name', ''))) in name_set
    ]
    # Annotations follow their image, whichever way the image was matched.
    image_ids = stem_set | {image['id'] for image in images}
    annotations = [
        ann for ann in data['annotations'] if ann['image_id'] in image_ids
    ]
    return images, annotations


def main(args):
    """Split a folder of .jpg images and its COCO-style JSON into train/val.

    The ``.jpg`` files directly inside ``args.img_path`` are shuffled with
    fixed seeds, copied into ``<img_path>/train`` and ``<img_path>/val``, and
    the records of ``args.json_file`` are written to ``train.json`` and
    ``val.json`` inside ``args.output``.  Every ``category_id`` and category
    ``id`` is decremented by one before writing.

    Args:
        args: Namespace with ``img_path`` (image folder), ``json_file``
            (annotation file) and ``output`` (directory for the two JSON
            files; an empty string means the current directory).  An optional
            ``train_ratio`` attribute sets the fraction of images that go to
            the train split; it defaults to 0.5.
    """
    # The two seeded shuffles are kept exactly as before so that an existing
    # dataset is split the same way as by earlier versions of this script.
    all_data = os.listdir(args.img_path)
    random.seed(1)
    random.shuffle(all_data)
    all_data_img = [name for name in all_data if name.endswith(".jpg")]
    num_all_data = len(all_data_img)
    print("num_all_data: " + str(num_all_data))

    index_list = list(range(num_all_data))
    random.seed(2)
    random.shuffle(index_list)

    train_dir = os.path.join(args.img_path, "train")
    valid_dir = os.path.join(args.img_path, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(valid_dir, exist_ok=True)

    train_ratio = getattr(args, "train_ratio", 0.5)
    train_list, val_list = [], []      # file-name stems, compared with ids
    train_files, val_files = [], []    # full file names, compared with file_name
    for position, idx in enumerate(index_list):
        image_name = all_data_img[idx]
        source = os.path.join(args.img_path, image_name)
        # Stem is the text before the FIRST dot, as in earlier versions.
        stem = image_name.split('.')[0]
        if position < num_all_data * train_ratio:
            train_list.append(stem)
            train_files.append(image_name)
            copy2(source, os.path.join(train_dir, image_name))
        else:
            val_list.append(stem)
            val_files.append(image_name)
            copy2(source, os.path.join(valid_dir, image_name))
    print("train_nums", len(train_list))
    print("val_nums", len(val_list))

    with open(args.json_file, encoding="utf-8") as f:
        data = json.load(f)

    # Shift category ids down by one (once per record) so they start at 0
    # when the input numbers them from 1.
    for ann in data['annotations']:
        ann['category_id'] -= 1
    for category in data['categories']:
        category['id'] -= 1

    train_images, train_annotations = _select_split(data, train_list, train_files)
    val_images, val_annotations = _select_split(data, val_list, val_files)

    train_json_dict = {
        "images": train_images,
        "annotations": train_annotations,
        "categories": list(data['categories']),
        "type": "instances",
    }
    val_json_dict = {
        "images": val_images,
        "annotations": val_annotations,
        "categories": list(data['categories']),
        "type": "instances",
    }

    # An empty output path means "current directory"; nothing to create then.
    if args.output:
        os.makedirs(args.output, exist_ok=True)
    with open(os.path.join(args.output, "train.json"), "w") as f:
        json.dump(train_json_dict, f, indent=2)
    with open(os.path.join(args.output, "val.json"), "w") as f:
        json.dump(val_json_dict, f, indent=2)
if __name__ == "__main__":
    # Command-line entry point: parse the three paths and run the split.
    parser = argparse.ArgumentParser(description='Start convert.')
    # Folder that holds the source images.  Required: without it main() would
    # receive None instead of a path.
    parser.add_argument('--img_path', type=str, required=True,
                        help='raw images path')
    # Annotation JSON file to split.  Required for the same reason.
    parser.add_argument('--json_file', type=str, required=True,
                        help='json file path')
    # Directory that receives train.json and val.json; the empty default
    # writes them into the current working directory.
    parser.add_argument('--output', type=str, default='',
                        help='output path')
    args = parser.parse_args()
    main(args)