最近在尝试目标检测的 YOLOv3 模型。首先需要准备训练数据:自己找了一些图片并做了标注,标注工具是 iPhotoDraw,生成的标注文件是 XML 格式,所以写了一个 Python 3 脚本,把 XML 文件里的数据转换成 YOLOv3 需要的格式。
#-*- coding:utf-8 _*-
"""
@author:xxx
@file: get_yolov3_train_data.py.py
@time: 2020/02/28
"""
#coding:utf-8
#@Time:2017/6/16 19:37
#@author: Steve
import os
import re
import math
import string
from tqdm import tqdm
import numpy as np
# 提取xml文件中的矩形框四个角点跟标签
def get_new_coord(center_coord,ori_coord,rotate_angle):
    """Rotate a point about a centre and truncate the result to integers.

    Args:
        center_coord: ``[x, y]`` of the rotation centre.
        ori_coord: ``[x, y]`` of the point to rotate.
        rotate_angle: Rotation angle in degrees.

    Returns:
        ``(x, y)`` of the rotated point; each value is passed through
        ``int()``, i.e. truncated toward zero.
    """
    # Degrees -> radians, evaluated once and shared by both axes.
    theta = (rotate_angle / 180.) * math.pi
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)

    centre_x, centre_y = center_coord[0], center_coord[1]
    # Offset of the point relative to the rotation centre.
    dx = ori_coord[0] - centre_x
    dy = ori_coord[1] - centre_y

    x_new = dx * cos_t + dy * sin_t + centre_x
    y_new = dy * cos_t - dx * sin_t + centre_y
    return int(x_new), int(y_new)
def get_coord_range(item):
    """Return the first four double-quoted values found in *item*.

    The values are returned in the order they appear in the string and are
    named X, Y, Width and Height after the order the caller expects them in.

    Args:
        item: Tag text containing at least four ``"..."`` quoted values.

    Returns:
        Tuple ``(X, Y, Width, Height)`` of strings (no numeric conversion).

    Raises:
        IndexError: If *item* contains fewer than eight ``"`` characters.
    """
    # Index of the character that follows each double quote.
    after_quote = [pos + 1 for pos, ch in enumerate(item) if ch == '"']

    fields = []
    for k in range(0, 8, 2):
        # Opening quote k and closing quote k + 1 delimit one value; the
        # "- 1" steps back onto the closing quote itself (exclusive end).
        fields.append(item[after_quote[k]:after_quote[k + 1] - 1])

    X, Y, Width, Height = fields
    return X, Y, Width, Height
def GetItemLocation(xml_file):
    """Extract rectangle and polygon annotations from an annotation XML file.

    The file is not parsed as XML: every newline and space is removed and the
    remaining text is scanned with regular expressions. Because spaces are
    stripped globally, label texts come back without spaces as well.

    Args:
        xml_file: Path of the annotation file; it is read as UTF-8.

    Returns:
        A 4-tuple ``(label_list_new, point_list, coord_poly_out,
        poly_label_out)``:

        * ``label_list_new`` - text of every ``<Text>`` element found outside
          the polygon sections.
        * ``point_list`` - one ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]``
          entry per ``<Extent .../>`` tag (upper-left, upper-right,
          lower-right, lower-left before rotation), as integers.
        * ``coord_poly_out`` - one array of shape ``(n_points, 2)`` per
          polygon ``<Points>`` section, holding truncated ``x, y`` values.
        * ``poly_label_out`` - text of every ``<Text>`` element found inside
          the polygon sections.

    Raises:
        OSError: If *xml_file* cannot be opened.
        IndexError: If there are fewer ``<Data ...>`` tags than
            ``<Extent .../>`` tags, or an ``<Extent>`` tag has fewer than
            four quoted values.
        ValueError: If a coordinate or angle is not a valid number.
    """
    # The context manager guarantees the handle is closed, also on errors.
    with open(xml_file, "r", encoding="utf-8") as fid:
        str1 = "".join(line.replace("\n", "").replace(" ", "") for line in fid)

    # ---- polygon shapes ------------------------------------------------
    poly_chunks = re.findall("Polygon.*?</Points>", str1)
    for chunk in poly_chunks:
        # Cut the polygon sections out so the rectangle pass below does not
        # pick up their <Text> elements.
        str1 = str1.replace(chunk, '')
    poly_str = ''.join(poly_chunks)

    # [6:-7] strips the literal "<Text>" and "</Text>" wrappers.
    poly_label_out = [tag[6:-7] for tag in re.findall("<Text>.*?</Text>", poly_str)]

    coord_poly_out = []
    for coord_polys in re.findall("<Points>.*?</Points>", poly_str):
        # Each X match looks like  X="12.5"Y  -> drop 'X="' and '"Y'.
        xs = [int(float(m[3:-2])) for m in re.findall("X=.*?Y", coord_polys)]
        # Each Y match looks like  Y="34.0"/> -> drop 'Y="' and '"/>'.
        ys = [int(float(m[3:-3])) for m in re.findall("Y=.*?/>", coord_polys)]
        x_col = np.array(xs).reshape(-1, 1)
        y_col = np.array(ys).reshape(-1, 1)
        coord_poly_out.append(np.hstack((x_col, y_col)))

    # ---- rectangle shapes ----------------------------------------------
    label_list = re.findall("<Text>.*?</Text>", str1)
    position_list = re.findall("<Extent.*?/>", str1)
    angle_list = re.findall("<Data.*?>", str1)

    # Label text of each box.
    label_list_new = [tag[6:-7] for tag in label_list]
    # Rotation angle of each box: the last quoted value of every <Data> tag.
    angle_list_new = [float(tag.split('"')[-2]) for tag in angle_list]
    # Upper-left corner plus width / height of each box (still strings).
    position_list_new = [list(get_coord_range(tag)) for tag in position_list]

    # Corner coordinates of each annotated box.
    point_list = []
    for i, value in enumerate(position_list_new):
        # NOTE(review): <Data> and <Extent> tags are paired purely by
        # position; confirm every shape in the file carries both.
        angle = angle_list_new[i]

        left = float(value[0])
        top = float(value[1])
        right = left + float(value[2])
        bottom = top + float(value[3])

        x1, y1 = int(left), int(top)        # upper-left
        x2, y2 = int(right), int(top)       # upper-right
        x3, y3 = int(right), int(bottom)    # lower-right
        x4, y4 = int(left), int(bottom)     # lower-left

        if angle != 0:
            # The stored angle is negated before being applied.
            angle = -angle
            center = [(x1 + x2 + x3 + x4) / 4, (y1 + y2 + y3 + y4) / 4]
            x1, y1 = get_new_coord(center, [x1, y1], angle)
            x2, y2 = get_new_coord(center, [x2, y2], angle)
            x3, y3 = get_new_coord(center, [x3, y3], angle)
            x4, y4 = get_new_coord(center, [x4, y4], angle)

        # Only these four values are clamped at zero, exactly as before.
        # NOTE(review): x2, x3, y3 and y4 can also go negative after a
        # rotation - confirm whether they should be clamped too.
        x1 = max(x1, 0)
        y1 = max(y1, 0)
        y2 = max(y2, 0)
        x4 = max(x4, 0)

        point_list.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])

    return label_list_new, point_list, coord_poly_out, poly_label_out
def get_yolov3_box(coord,img):
    """Convert a corner-format box into image-relative centre/size ratios.

    Args:
        coord: ``[x1, y1, x2, y2]`` - first corner followed by the opposite
            corner of the box.
        img: Object whose ``shape`` starts with ``(height, width)``.

    Returns:
        List ``[x_center, y_center, width, height]``; horizontal values are
        divided by the image width, vertical values by the image height.
    """
    img_h, img_w = img.shape[0:2]
    left, top, right, bottom = coord[0], coord[1], coord[2], coord[3]
    return [
        ((left + right) / 2) / img_w,   # box centre, x
        ((top + bottom) / 2) / img_h,   # box centre, y
        (right - left) / img_w,         # box width
        (bottom - top) / img_h,         # box height
    ]
def _build_yolov3_dataset(path, target_path_root, data_name, train_type):
    """Convert annotated JPEGs in *path* into a YOLOv3 training layout.

    For every ``*.jpg`` directly inside *path*, the annotation file
    ``<stem>_data.xml`` next to it is parsed with ``GetItemLocation``. Images
    that cannot be read, or that carry no polygon annotation, are printed and
    skipped. Any other per-image failure prints the file name and traceback
    and the run continues with the next image.

    Written under ``<target_path_root>/data/<data_name>/``:

    * ``classes.names`` - one entry of *train_type* per line.
    * ``train.txt`` and ``valid.txt`` - both list every converted image
      (identical content) as ``data/<data_name>/images/<image name>``.
    * ``images/`` - copy of each converted image.
    * ``labels/<stem>.txt`` - one ``class x_center y_center width height``
      line per polygon, in image-relative ratios.
    * ``results/`` - the image with the polygon bounding boxes drawn on it.

    Additionally ``<path>/<stem>.txt`` receives one
    ``x1,y1,x2,y2,x3,y3,x4,y4,label`` line per rectangle annotation.

    Args:
        path: Directory holding the ``*.jpg`` images and ``*_data.xml`` files.
        target_path_root: Root directory of the YOLOv3 project.
        data_name: Name of the data set sub-directory to create.
        train_type: Iterable of class names written to ``classes.names``.
    """
    import glob
    import shutil
    import traceback

    import cv2

    data_root = os.path.join(target_path_root, 'data', data_name)
    images_dir = os.path.join(data_root, 'images')
    results_dir = os.path.join(data_root, 'results')
    labels_dir = os.path.join(data_root, 'labels')
    # makedirs also creates a missing parent "data" directory, which a plain
    # os.mkdir would fail on.
    for directory in (images_dir, results_dir, labels_dir):
        os.makedirs(directory, exist_ok=True)

    with open(os.path.join(data_root, 'classes.names'), 'w+', encoding='utf-8') as fid_type:
        for item in train_type:
            fid_type.write(item + '\n')

    jpg_files = glob.glob(os.path.join(path, '*.jpg'))

    # All long-lived handles are context-managed so they are flushed and
    # closed even if the loop is interrupted.
    with open(os.path.join(data_root, 'train.txt'), 'w+', encoding='utf-8') as fid_train, \
            open(os.path.join(data_root, 'valid.txt'), 'w+', encoding='utf-8') as fid_val, \
            tqdm(total=len(jpg_files)) as bar:
        for file in jpg_files:
            bar.update(1)
            try:
                img = cv2.imread(file)
                if img is None:
                    # cv2.imread signals an unreadable file by returning None
                    # instead of raising; skip it rather than abort the run.
                    print(file)
                    continue
                h, w = img.shape[0:2]

                # basename/splitext work with either path separator and keep
                # stems that contain dots intact.
                img_name = os.path.basename(file)
                stem = os.path.splitext(img_name)[0]
                xml_file = os.path.splitext(file)[0] + '_data.xml'

                label_list, position_list, coord_poly_out, poly_label_out = GetItemLocation(xml_file)
                if not poly_label_out:
                    print(file)
                    continue

                # Build every label line first, so that a failure on a later
                # polygon does not leave a half-written label file behind.
                label_lines = []
                for ii, poly_label in enumerate(poly_label_out):
                    # The polygon text is read as an integer and shifted down
                    # by one (presumably classes are numbered from 1 in the
                    # annotations - verify against the labelling convention).
                    id_label = int(poly_label) - 1
                    points = coord_poly_out[ii]
                    # Axis-aligned bounding box of the polygon, clamped to
                    # the image area.
                    list_coord = [
                        max(int(np.min(points[:, 0])), 0),
                        max(int(np.min(points[:, 1])), 0),
                        min(int(np.max(points[:, 0])), w - 1),
                        min(int(np.max(points[:, 1])), h - 1),
                    ]
                    img = cv2.rectangle(img, (list_coord[0], list_coord[1]),
                                        (list_coord[2], list_coord[3]), (0, 0, 255))
                    coord_ratio = get_yolov3_box(list_coord, img)
                    label_lines.append(
                        str(id_label) + ' ' + ' '.join(str(x) for x in coord_ratio) + '\n')

                with open(os.path.join(labels_dir, stem + '.txt'), 'w+', encoding='utf-8') as fid_box:
                    fid_box.writelines(label_lines)

                image_entry = os.path.join('data', data_name, 'images', img_name) + '\n'
                fid_train.write(image_entry)
                fid_val.write(image_entry)
                shutil.copy(file, images_dir)
                cv2.imwrite(os.path.join(results_dir, img_name), img)

                # Corner order as produced by GetItemLocation: upper-left,
                # upper-right, lower-right, lower-left.
                with open(os.path.join(path, stem + '.txt'), 'w+', encoding='utf-8') as fid:
                    for i, label in enumerate(label_list):
                        position = position_list[i]
                        corners = [str(v) for corner in position for v in corner]
                        fid.write(','.join(corners) + ',' + label + '\n')
            except Exception:
                # Best effort per image: report and carry on. Unlike a bare
                # "except:", this lets Ctrl-C stop the run.
                print(file)
                traceback.print_exc()


if __name__ == '__main__':
    path = r'/src/notebook/detect_id/detect'
    target_path_root = r'/src/notebook/detect_id/PyTorch-YOLOv3-master'
    data_name = 'SFZ'
    train_type = ['ID_Z', 'ID_F']
    _build_yolov3_dataset(path, target_path_root, data_name, train_type)