# 功能: 生成 darknet 训练所需的 train.txt 和 valid.txt 图片列表文件。(Purpose: generate the train.txt / valid.txt image-list files that darknet uses for training.)
# coding:utf-8
import os
import io
import math
import sys
import random
import argparse
from collections import namedtuple, OrderedDict
# Fraction of the label files (by shuffled position) that convert_dataset()
# routes to valid.txt; the remaining files go to train.txt.
# (Original note: 设置测试集比例 -- "set the test-set ratio".)
valid_percent = 0.04

# Class-name list.  NOTE(review): not referenced anywhere in the visible code,
# and convert_dataset() rejects class indices >= 1 -- confirm whether these
# four names still describe the dataset.
label_names = [
    'person',
    'car',
    'bus',
    'truck',
]
def get_files(dir, suffix):
    """Recursively list the files under *dir* whose extension equals *suffix*.

    The extension is taken with ``os.path.splitext`` and compared exactly, so
    the match is case-sensitive (``'.txt'`` does not match ``'a.TXT'``).
    ``dir`` shadows the builtin but is kept for caller compatibility.

    Returns:
        A list of paths joined onto each ``os.walk`` root, in ``os.walk`` order.
    """
    return [
        os.path.join(folder, entry)
        for folder, _subdirs, entries in os.walk(dir)
        for entry in entries
        if os.path.splitext(entry)[1] == suffix
    ]
def gbbox_iou(box1, box2, offset=1):
    """Return the fraction of *box1*'s area that is covered by *box2*.

    Despite the name, this is NOT the symmetric intersection-over-union: the
    intersection is divided by the area of ``box1`` only, so the result is
    1.0 whenever ``box2`` fully contains ``box1``.

    Args:
        box1: ``(x1, y1, x2, y2)`` corner coordinates of the reference box.
        box2: ``(x1, y1, x2, y2)`` corner coordinates of the other box.
        offset: added to every width and height.  The default ``1`` keeps the
            original inclusive-pixel convention (x=0..9 is 10 pixels wide).
            Pass ``0`` for continuous / normalized coordinates (e.g. darknet's
            0..1 label values), where a ``+1`` would dominate the result.

    Returns:
        ``intersection / area(box1)``, or ``0`` when the boxes do not overlap.
    """
    b1_x1, b1_y1, b1_x2, b1_y2 = box1
    b2_x1, b2_y1, b2_x2, b2_y2 = box2
    inter_width = min(b1_x2, b2_x2) - max(b1_x1, b2_x1) + offset
    inter_height = min(b1_y2, b2_y2) - max(b1_y1, b2_y1) + offset
    if inter_width <= 0 or inter_height <= 0:
        return 0
    # inter_width > 0 implies box1's (width + offset) > 0 too, so no zero division.
    b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset)
    # The symmetric IoU would instead divide by (b1_area + b2_area - intersection).
    return inter_width * inter_height / b1_area
def _label_file_is_valid(label_file, num_classes=1):
    """Check every annotation line of one darknet label file.

    Each non-blank line is expected to hold ``class cx cy w h``.  A line is
    rejected when it has fewer than five fields or non-numeric values, a
    non-positive width/height, or a class index outside ``[0, num_classes)``.
    Each rejected line is reported on stdout; blank lines are ignored.

    Returns:
        True when no line was rejected.
    """
    with open(label_file, 'r') as fd:
        rows = [line.split() for line in fd]
    valid = True
    for row in rows:
        if not row:
            continue  # blank / trailing lines carry no annotation
        try:
            class_index = int(row[0])
            # Fewer than five fields makes this unpacking raise ValueError.
            _cx, _cy, width, height = (float(v) for v in row[1:5])
        except ValueError:
            valid = False
            print('\n malformed line: ', row, 'label_file', label_file)
            continue
        if width <= 0 or height <= 0:
            valid = False
            print('\n error index: ', class_index, 'label_file', label_file)
            continue
        if not 0 <= class_index < num_classes:
            valid = False
            print('\n error index: ', class_index, 'label_file', label_file)
        # Center/size values are deliberately not range-checked against [0, 1].
    return valid


def convert_dataset(list_path, output_file, num_classes=1, image_suffix='.jpg', seed=None):
    """Write darknet ``train.txt`` and ``valid.txt`` image lists.

    Every ``*.txt`` label file found recursively under *list_path* is paired
    with the image that has the same stem and *image_suffix*.  The label list
    is shuffled; a label whose shuffled position is below
    ``int(total * valid_percent)`` sends its image to ``valid.txt``, the rest
    go to ``train.txt``.  Labels whose image is missing or whose content fails
    validation are skipped and counted in the progress line's error count.

    Args:
        list_path: directory searched recursively for label files.
        output_file: output DIRECTORY (not a file) that receives
            ``train.txt`` and ``valid.txt``; created if it does not exist.
        num_classes: class indices must lie in ``[0, num_classes)``.  The
            default ``1`` keeps the original single-class behaviour.
        image_suffix: extension of the image stored next to each label file.
        seed: optional seed for a reproducible shuffle.  ``None`` keeps using
            the module-level ``random`` state, as before.
    """
    label_list = get_files(list_path, '.txt')
    total_label_len = len(label_list)
    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(label_list)
    print('total_label_len', total_label_len)
    valid_count = int(total_label_len * valid_percent)

    if output_file:
        os.makedirs(output_file, exist_ok=True)
    train_output_file = os.path.join(output_file, 'train.txt')
    valid_output_file = os.path.join(output_file, 'valid.txt')

    error_count = 0
    # surrogateescape writes filename bytes that os.walk could not decode back
    # out unchanged (POSIX, when locale and filesystem encodings agree)
    # instead of crashing with UnicodeEncodeError.
    with open(train_output_file, 'w', errors='surrogateescape') as fp_train, \
            open(valid_output_file, 'w', errors='surrogateescape') as fp_valid:
        for i, label_file in enumerate(label_list):
            sys.stdout.write('\r>> Calculating {}/{} error{}'.format(
                i + 1, total_label_len, error_count))
            sys.stdout.flush()
            image_path = os.path.splitext(label_file)[0] + image_suffix
            if not os.path.exists(image_path):
                error_count += 1
                print("error_file: ", label_file.encode('UTF-8', 'ignore').decode('UTF-8'))
                continue
            if not _label_file_is_valid(label_file, num_classes):
                error_count += 1
                continue
            # NOTE: a commented-out filter that dropped images whose class-0
            # (person) box lay inside another box (gbbox_iou > 0.99, "person
            # in car") used to sit here; it was never active.
            if i < valid_count:
                fp_valid.write(image_path + '\n')
                print(" valid image_path: ", image_path)
            else:
                fp_train.write(image_path + '\n')
                print(" train image_path: ", image_path)
    print('total_label_len', total_label_len)
def main(argv=None):
    """Command-line entry point: parse arguments and build the list files.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]`` (useful for
            tests or calls from other scripts); ``None`` keeps the original
            behaviour.
    """
    parser = argparse.ArgumentParser(
        prog='gen_label_list.py',
        description='Generate darknet train.txt / valid.txt image lists.')
    parser.add_argument(
        '--img-path', type=str,
        default='/root/zhangsong/fairworks/github/darknet-master/fireworks_yolov4_tiny/data/smoke',
        help='directory searched recursively for *.txt label files and their .jpg images')
    parser.add_argument(
        '--valid', type=str, default='fireworks_yolov4_tiny/data',
        help='output directory for train.txt and valid.txt')
    opt = parser.parse_args(argv)
    print(opt.img_path, opt.valid)
    convert_dataset(opt.img_path, opt.valid)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()