将 iPhotoDraw 生成的 XML 标注文件转换为 YOLOv3 训练所需的格式

    最近在尝试目标检测模型 YOLOv3,首先需要准备训练数据。我自己找了一些图片并进行了标注,标注工具是 iPhotoDraw,它生成的是 XML 格式的文件,所以写了一个 Python 3 脚本,把 XML 文件中的标注数据转换成 YOLOv3 所需的格式。

 

  

#-*- coding:utf-8 _*-
"""
@author:xxx
@file: get_yolov3_train_data.py.py
@time: 2020/02/28
"""
#coding:utf-8
#@Time:2017/6/16 19:37
#@author: Steve
import os
import re
import math
import string
from tqdm import tqdm
import numpy as np
# 提取xml文件中的矩形框四个角点跟标签


def get_new_coord(center_coord,ori_coord,rotate_angle):
    """Rotate a point about a pivot and return truncated integer coordinates.

    With ``(dx, dy) = ori_coord - center_coord`` and ``a`` the angle in radians::

        x' = dx*cos(a) + dy*sin(a) + cx
        y' = dy*cos(a) - dx*sin(a) + cy

    Args:
        center_coord: ``(cx, cy)`` pivot point.
        ori_coord: ``(x, y)`` point to rotate.
        rotate_angle: rotation angle in degrees.

    Returns:
        Tuple ``(x', y')`` converted with ``int()`` (truncation toward zero).
    """
    theta = (rotate_angle / 180.) * math.pi
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    pivot_x, pivot_y = center_coord[0], center_coord[1]
    dx = ori_coord[0] - pivot_x
    dy = ori_coord[1] - pivot_y
    rotated_x = dx * cos_t + dy * sin_t + pivot_x
    rotated_y = dy * cos_t - dx * sin_t + pivot_y
    return int(rotated_x), int(rotated_y)


def get_coord_range(item):
    """Return the first four double-quoted values of an ``<Extent .../>`` tag.

    The caller passes tags with spaces already removed, e.g.
    ``<ExtentX="10"Y="20"Width="30"Height="40"/>``. The values are returned
    as raw strings in tag order. NOTE(review): this assumes iPhotoDraw always
    writes X, Y, Width, Height in that order; confirm.

    Raises:
        IndexError: if *item* contains fewer than eight double quotes.
    """
    # Index just past every '"' character in the tag.
    after_quote = [pos + 1 for pos, ch in enumerate(item) if ch == '"']
    # Quotes pair up as (open, close): values sit between quotes 0-1, 2-3, 4-5, 6-7.
    x_val, y_val, width_val, height_val = (
        item[after_quote[k]:after_quote[k + 1] - 1] for k in range(0, 8, 2)
    )
    return x_val, y_val, width_val, height_val


def _read_compact_xml(xml_file):
    """Return the contents of *xml_file* with every newline and space removed.

    Stripping whitespace lets the regular expressions below match tags such as
    ``<ExtentX="..."Y="..."/>`` regardless of formatting. Note that it also
    removes spaces from label text.
    """
    with open(xml_file, "r", encoding="utf-8") as fid:
        return "".join(line.replace("\n", "").replace(" ", "") for line in fid)


def _split_polygons(text):
    """Extract polygon annotations from the compacted XML *text*.

    A polygon chunk runs from the word ``Polygon`` to the next ``</Points>``.

    Returns:
        ``(rest, labels, points)``:

        * rest: *text* with the polygon chunks removed.
        * labels: the ``<Text>`` values found inside those chunks.
        * points: one integer ``(N, 2)`` array of (x, y) vertices per
          ``<Points>`` element.
    """
    poly_chunks = re.findall("Polygon.*?</Points>", text)
    for chunk in poly_chunks:
        text = text.replace(chunk, '')
    poly_str = ''.join(poly_chunks)

    labels = [tag[6:-7] for tag in re.findall("<Text>.*?</Text>", poly_str)]

    points = []
    for coord_polys in re.findall("<Points>.*?</Points>", poly_str):
        # 'X="12.5"Y' -> '12.5' and 'Y="30"/>' -> '30'; floats are truncated to int.
        xs = np.array([int(float(m[3:-2])) for m in re.findall("X=.*?Y", coord_polys)]).reshape(-1, 1)
        ys = np.array([int(float(m[3:-3])) for m in re.findall("Y=.*?/>", coord_polys)]).reshape(-1, 1)
        points.append(np.hstack((xs, ys)))
    return text, labels, points


def _parse_rectangles(text):
    """Extract rectangle annotations from *text* after polygons were removed.

    Returns:
        ``(labels, corners_list)``:

        * labels: the ``<Text>`` values.
        * corners_list: for each ``<Extent .../>`` tag, a list
          ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]`` of int corners.
    """
    labels = [tag[6:-7] for tag in re.findall("<Text>.*?</Text>", text)]
    # NOTE(review): this assumes the last quoted attribute of each <Data ...>
    # tag is the rotation angle in degrees. Confirm against the iPhotoDraw schema.
    angles = [float(tag.split('"')[-2]) for tag in re.findall("<Data.*?>", text)]
    extents = [get_coord_range(tag) for tag in re.findall("<Extent.*?/>", text)]

    corners_list = []
    for i, (x, y, width, height) in enumerate(extents):
        # The i-th <Extent> is paired with the i-th <Data>. This raises
        # IndexError if there are fewer <Data> tags than <Extent> tags.
        angle = angles[i]
        left = float(x)
        top = float(y)
        right = left + float(width)
        bottom = top + float(height)
        corners = [[int(left), int(top)],      # top-left
                   [int(right), int(top)],     # top-right
                   [int(right), int(bottom)],  # bottom-right
                   [int(left), int(bottom)]]   # bottom-left

        if angle != 0:
            # Rotate about the mean of the four corners, by the negated angle.
            center = [sum(c[0] for c in corners) / 4, sum(c[1] for c in corners) / 4]
            corners = [list(get_new_coord(center, c, -angle)) for c in corners]

        # Clamp every coordinate at 0. After rotation, any corner (not only
        # the left/top ones) may fall outside the image.
        corners_list.append([[max(px, 0), max(py, 0)] for px, py in corners])
    return labels, corners_list


def GetItemLocation(xml_file):
    """Parse an iPhotoDraw annotation file into rectangle and polygon annotations.

    Args:
        xml_file: path to the UTF-8 XML annotation file.

    Returns:
        Tuple ``(label_list, point_list, coord_poly_out, poly_label_out)``:

        * label_list: ``<Text>`` labels of the non-polygon (rectangle) shapes.
        * point_list: one entry per rectangle, ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]``.
          These are int corners ordered top-left, top-right, bottom-right,
          bottom-left (before rotation). When the shape's angle is non-zero,
          the corners are rotated about their mean by the negated angle.
          Every coordinate is clamped to be >= 0.
        * coord_poly_out: one int ``(N, 2)`` array of (x, y) vertices per polygon.
        * poly_label_out: ``<Text>`` labels of the polygons.

    Spaces are stripped from the whole file before parsing, so labels lose any
    internal spaces. Rectangles, angles and labels are paired purely by the
    order in which their tags appear.
    """
    text = _read_compact_xml(xml_file)
    text, poly_label_out, coord_poly_out = _split_polygons(text)
    label_list, point_list = _parse_rectangles(text)
    return label_list, point_list, coord_poly_out, poly_label_out

def get_yolov3_box(coord,img):
    """Convert a pixel bounding box into YOLO's normalised (cx, cy, w, h) form.

    Args:
        coord: ``[x_min, y_min, x_max, y_max]`` in pixels.
        img: image array. Only ``img.shape[0:2]`` (height, width) is read.

    Returns:
        ``[cx / W, cy / H, (x_max - x_min) / W, (y_max - y_min) / H]``, where
        (cx, cy) is the box centre and W, H are the image width and height.
    """
    img_height, img_width = img.shape[0:2]
    x_min, y_min, x_max, y_max = coord[0], coord[1], coord[2], coord[3]
    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2
    return [center_x / img_width,
            center_y / img_height,
            (x_max - x_min) / img_width,
            (y_max - y_min) / img_height]


if __name__ == '__main__':
    import glob
    import traceback
    import cv2
    import shutil

    # Directory holding the <name>.jpg images and their iPhotoDraw
    # annotations, which are looked up as <name>_data.xml.
    path = r'/src/notebook/detect_id/detect'
    # Root of the PyTorch-YOLOv3 project that receives the converted dataset.
    target_path_root = r'/src/notebook/detect_id/PyTorch-YOLOv3-master'
    data_name = 'SFZ'
    # Class names written to classes.names. A polygon whose <Text> is "k"
    # becomes class id k - 1 in the label files.
    train_type = ['ID_Z', 'ID_F']

    data_root = os.path.join(target_path_root, 'data', data_name)
    images_dir = os.path.join(data_root, 'images')
    results_dir = os.path.join(data_root, 'results')
    labels_dir = os.path.join(data_root, 'labels')
    for out_dir in (images_dir, results_dir, labels_dir):
        os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(data_root, 'classes.names'), 'w+', encoding='utf-8') as fid_type:
        for item in train_type:
            fid_type.write(item + '\n')

    jpg_files = glob.glob(os.path.join(path, '*.jpg'))
    # NOTE: every converted image is listed in both train.txt and valid.txt,
    # i.e. the validation set is identical to the training set.
    with open(os.path.join(data_root, 'train.txt'), 'w+', encoding='utf-8') as fid_train, \
            open(os.path.join(data_root, 'valid.txt'), 'w+', encoding='utf-8') as fid_val, \
            tqdm(total=len(jpg_files)) as bar:
        for file in jpg_files:
            bar.update(1)
            img_name = os.path.basename(file)
            # splitext keeps dots inside the stem (e.g. "a.b.jpg" -> "a.b"), so
            # the label/XML names stay aligned with the image name.
            stem = os.path.splitext(img_name)[0]
            img = cv2.imread(file)
            if img is None:
                # cv2.imread signals an unreadable file by returning None rather than raising.
                print('failed to read image:', file)
                continue
            h, w = img.shape[0:2]
            try:
                xml_file = os.path.join(os.path.dirname(file), stem + '_data.xml')
                label_list, position_list, coord_poly_out, poly_label_out = GetItemLocation(xml_file)
                if len(poly_label_out) == 0:
                    print(file)
                    continue

                # One YOLO label line per polygon: "<class> <cx> <cy> <bw> <bh>",
                # normalised by the image size.
                with open(os.path.join(labels_dir, stem + '.txt'), 'w+', encoding='utf-8') as fid_box:
                    for ii, poly_label in enumerate(poly_label_out):
                        id_label = int(poly_label) - 1
                        poly = coord_poly_out[ii]
                        # Axis-aligned bounding box of the polygon, clamped to the image.
                        list_coord = [np.max([np.min(poly[:, 0]), 0]),
                                      np.max([np.min(poly[:, 1]), 0]),
                                      np.min([np.max(poly[:, 0]), w - 1]),
                                      np.min([np.max(poly[:, 1]), h - 1])]
                        img = cv2.rectangle(img, (int(list_coord[0]), int(list_coord[1])),
                                            (int(list_coord[2]), int(list_coord[3])), (0, 0, 255))
                        coord_ratio = get_yolov3_box(list_coord, img)
                        fid_box.write(str(id_label) + ' ' + ' '.join(str(x) for x in coord_ratio) + '\n')

                fid_train.write(os.path.join('data', data_name, 'images', img_name) + '\n')
                fid_val.write(os.path.join('data', data_name, 'images', img_name) + '\n')
                shutil.copy(file, images_dir)
                # Debug copy of the image with the polygon boxes drawn in (0, 0, 255).
                cv2.imwrite(os.path.join(results_dir, img_name), img)

                # Rectangle annotations, one line per box, written next to the image:
                # x1,y1,x2,y2,x3,y3,x4,y4,label. Corners are ordered top-left,
                # top-right, bottom-right, bottom-left (before rotation).
                with open(os.path.join(path, stem + '.txt'), 'w+', encoding='utf-8') as fid:
                    for i, label in enumerate(label_list):
                        position = position_list[i]
                        coords = ','.join(str(v) for corner in position[:4] for v in corner[:2])
                        fid.write(coords + ',' + label + '\n')
            except Exception:
                print(file)
                traceback.print_exc()

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值