生成 darknet 训练所需的 train.txt 和 valid.txt 文件的脚本

功能:遍历标注目录,生成 darknet 训练所需的 train.txt 和 valid.txt 图片列表文件,供训练使用。

# coding:utf-8
import os  
import io  
import math
import sys
import random
import argparse

from collections import namedtuple, OrderedDict  
# Fraction of the discovered label files whose images are written to
# valid.txt; the remaining images are written to train.txt.
valid_percent=0.04

# Class names; not referenced anywhere else in the visible code.
label_names = ['person','car','bus','truck']

def get_files(dir, suffix):
    """Recursively collect the files under *dir* whose extension is *suffix*.

    Args:
        dir: Root directory to walk.
        suffix: Extension to match, including the leading dot (e.g. '.txt').
            The comparison is exact and therefore case-sensitive.

    Returns:
        A list of paths (each joined onto the directory it was found in),
        in the order produced by ``os.walk``.
    """
    return [
        os.path.join(parent, entry)
        for parent, _subdirs, entries in os.walk(dir)
        for entry in entries
        if os.path.splitext(entry)[1] == suffix
    ]
def gbbox_iou(box1, box2):
    """Return the fraction of *box1* that is covered by *box2*.

    Both boxes are ``(x1, y1, x2, y2)`` sequences. Every extent is computed
    as ``max - min + 1``.

    Note that, despite the name, the result is ``intersection / area(box1)``
    and not the symmetric intersection-over-union, so swapping the arguments
    can change the result.

    Returns:
        ``intersection_area / box1_area`` when the boxes overlap, otherwise
        the integer ``0``.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Extents of the overlapping rectangle.
    overlap_w = min(ax2, bx2) - max(ax1, bx1) + 1
    overlap_h = min(ay2, by2) - max(ay1, by1) + 1

    if not (overlap_w > 0 and overlap_h > 0):
        return 0

    box1_area = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    return (overlap_w * overlap_h) / box1_area

def _label_file_is_valid(label_file, num_classes):
    """Return True when every annotation line of *label_file* is usable.

    Each non-blank line is expected to hold at least five whitespace-separated
    fields: ``class_index x_center y_center width height``. A line is rejected
    when it cannot be parsed, when its width or height is not positive, or
    when its class index falls outside ``[0, num_classes)``. Every rejected
    line is reported on stdout; the file is invalid if any line is rejected.
    """
    is_valid = True
    with open(label_file, 'r') as fd:
        for raw_line in fd:
            fields = raw_line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no annotation.
                continue
            try:
                class_index = int(fields[0])
                # Unpacking also raises ValueError when fields are missing.
                _, _, width, height = (float(v) for v in fields[1:5])
            except ValueError:
                is_valid = False
                print('\n error line: ', raw_line.strip(),
                      'label_file', label_file)
                continue
            if (width <= 0 or height <= 0
                    or class_index < 0 or class_index >= num_classes):
                is_valid = False
                print('\n error index: ', class_index,
                      'label_file', label_file)
    return is_valid


def convert_dataset(list_path, output_file, valid_ratio=None, num_classes=1):
    """Write the darknet ``train.txt`` and ``valid.txt`` image lists.

    All ``.txt`` label files under *list_path* are collected and shuffled.
    For each label file the image path is the label path with its extension
    replaced by ``.jpg``. Label files without such an image, or containing an
    unusable annotation line, are skipped. Image paths of label files whose
    shuffled position is below ``int(total * valid_ratio)`` go to
    ``valid.txt``; all others go to ``train.txt``. Because the split is by
    position, skipped files make the validation list shorter than that count.

    Args:
        list_path: Root directory searched recursively for label files.
        output_file: Directory in which ``train.txt`` and ``valid.txt`` are
            created (it must already exist).
        valid_ratio: Fraction of label files routed to ``valid.txt``. When
            None, the module-level ``valid_percent`` is used.
        num_classes: Number of accepted classes; a label line whose class
            index is not in ``[0, num_classes)`` invalidates its file.
    """
    if valid_ratio is None:
        valid_ratio = valid_percent

    # Collect every .txt label file below list_path.
    label_list = get_files(list_path, '.txt')
    total_label_len = len(label_list)
    random.shuffle(label_list)
    print('total_label_len', total_label_len)

    valid_count = int(total_label_len * valid_ratio)
    error_count = 0
    train_output_file = os.path.join(output_file, 'train.txt')
    valid_output_file = os.path.join(output_file, 'valid.txt')

    # The context manager closes both lists even if a label file fails to read.
    with open(train_output_file, 'w') as fp_train, \
            open(valid_output_file, 'w') as fp_valid:
        for i, label_file in enumerate(label_list):
            sys.stdout.write('\r>> Calculating {}/{} error{}'.format(
                i + 1, total_label_len, error_count))
            sys.stdout.flush()

            file_name, type_name = os.path.splitext(label_file)
            image_path = file_name + '.jpg'
            if type_name != '.txt' or not os.path.exists(image_path):
                error_count += 1
                print("error_file: ",
                      label_file.encode('UTF-8', 'ignore').decode('UTF-8'))
                continue

            if not _label_file_is_valid(label_file, num_classes):
                continue

            # NOTE: a filter that dropped images in which a person box lies
            # inside a vehicle box (via gbbox_iou) used to sit here, disabled.
            if i < valid_count:
                fp_valid.write(image_path)
                fp_valid.write('\n')
                print(" valid image_path: ", image_path)
            else:
                fp_train.write(image_path)
                fp_train.write('\n')
                print(" train image_path: ", image_path)

    print('total_label_len', total_label_len)


def main():
    """Parse the command-line options and generate the train/valid lists.

    ``--img-path`` is the directory searched for label files and ``--valid``
    is the directory that receives the generated list files; both are passed
    straight to ``convert_dataset``.
    """
    arg_parser = argparse.ArgumentParser(prog='gen_label_list.py')
    arg_parser.add_argument(
        '--img-path',
        type=str,
        default='/root/zhangsong/fairworks/github/darknet-master/fireworks_yolov4_tiny/data/smoke',
        help='test path',
    )
    arg_parser.add_argument(
        '--valid',
        type=str,
        default='fireworks_yolov4_tiny/data',
        help='*.txt path',
    )
    options = arg_parser.parse_args()

    label_root, list_dir = options.img_path, options.valid
    print(label_root, list_dir)
    convert_dataset(label_root, list_dir)


# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值