YOLOv3 for TensorFlow 2.0, source code: https://github.com/YunYang1994/TensorFlow2.0-Examples/tree/master/4-Object_Detection/YOLOV3
Train on your own dataset:
Dataset: images + XML annotation files
Before training, the XML files must be converted into a txt file that looks like this:
xxx/xxx.jpg 18.19,6.32,424.13,421.83,20 323.86,2.65,640.0,421.94,20
xxx/xxx.jpg 48,240,195,371,11 8,12,352,498,14
# image_path x_min,y_min,x_max,y_max,class_id x_min,y_min,...,class_id
# make sure that x_max < width and y_max < height
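To make the line format concrete, here is a minimal sketch of a parser for one annotation line, plus a check of the size constraint stated above. The names `parse_annotation_line` and `boxes_fit_image` and the 500x500 image size are hypothetical and used only for illustration. The sketch also assumes the image path contains no spaces.

```python
def parse_annotation_line(line):
    """Split one annotation line into (image_path, boxes).

    Each box is (x_min, y_min, x_max, y_max, class_id).
    Assumes the image path contains no spaces.
    """
    parts = line.strip().split()
    image_path = parts[0]
    boxes = []
    for field in parts[1:]:
        x_min, y_min, x_max, y_max, class_id = field.split(',')
        boxes.append((float(x_min), float(y_min),
                      float(x_max), float(y_max), int(class_id)))
    return image_path, boxes


def boxes_fit_image(boxes, width, height):
    """Return True if every box lies inside the image (x_max < width, y_max < height)."""
    return all(
        0 <= x_min < x_max < width and 0 <= y_min < y_max < height
        for x_min, y_min, x_max, y_max, _ in boxes
    )


# Hypothetical usage: the 500x500 image size is an assumption for this example
path, boxes = parse_annotation_line("xxx/xxx.jpg 48,240,195,371,11 8,12,352,498,14")
print(path, boxes)
print(boxes_fit_image(boxes, width=500, height=500))
```

Running a check like this over the generated txt file before training can catch boxes that touch or exceed the image border.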
So we need an xml_to_txt script:
"""
Things to change before running:
1. CLASSES: your own class names
2. data_path: the path to your dataset
3. The save path and file name of the generated txt file (the open(...) call below)
"""
import os
import xml.etree.ElementTree as ET

CLASSES = ['apple', 'pear', 'tomato', 'eggplant']


def convert_xml_annotation(data_path, classes):
    # Collect all XML annotation files in the dataset directory
    xml_dir = []
    for xml in os.listdir(data_path):
        if xml.endswith('.xml'):
            xml_dir.append(xml)
    print("Total xml files : ", len(xml_dir))

    with open("D:/TensorFlow2.0-Examples/4-Object_Detection/YOLOV3/data/dataset/fruits_train.txt", 'w') as f:
        for i in range(len(xml_dir)):
            # os.path.join works whether or not data_path ends with a separator
            tree = ET.parse(os.path.join(data_path, xml_dir[i]))
            root = tree.getroot()

            # image path
            filename = root.find('filename').text
            image_path = os.path.join(data_path, filename)
            annotation = image_path

            # coordinates of label : xmin ymin xmax ymax
            for obj in root.iter('object'):
                # Treat a missing <difficult> tag as 0 (not difficult)
                difficult = obj.findtext('difficult', default='0')
                cls = obj.find('name').text
                # Skip unknown classes and objects marked as difficult
                if cls not in classes or int(difficult) == 1:
                    continue
                cls_id = classes.index(cls)
                bbox = obj.find('bndbox')
                xmin = bbox.find('xmin').text.strip()
                xmax = bbox.find('xmax').text.strip()
                ymin = bbox.find('ymin').text.strip()
                ymax = bbox.find('ymax').text.strip()
                annotation += ' ' + ','.join([xmin, ymin, xmax, ymax, str(cls_id)])

            # One line per image: the image path followed by all of its boxes
            print(annotation)
            f.write(annotation + "\n")


convert_xml_annotation("D:/SoftWare/easydl2labelImg-master/Data/Fruits/", CLASSES)
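The script above copies the coordinates exactly as they appear in the XML, so it does not enforce the requirement that x_max < width and y_max < height. The following is a minimal sketch of one way to clamp the boxes. It assumes the XML contains a Pascal VOC style `size` element, as labelImg normally writes. The helper name `clamped_box` and the sample XML are hypothetical. Note that fractional coordinates are truncated to integers here.

```python
import xml.etree.ElementTree as ET

# Hypothetical sample annotation; the box touches the right and bottom borders
SAMPLE_XML = """
<annotation>
  <filename>demo.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>apple</name>
    <difficult>0</difficult>
    <bndbox><xmin>600</xmin><ymin>10</ymin><xmax>640</xmax><ymax>480</ymax></bndbox>
  </object>
</annotation>
"""


def clamped_box(obj, width, height):
    """Return (xmin, ymin, xmax, ymax) as ints, clamped inside the image."""
    bbox = obj.find('bndbox')
    # int(float(...)) accepts both "48" and "18.19" and truncates the fraction
    xmin = max(0, int(float(bbox.find('xmin').text)))
    ymin = max(0, int(float(bbox.find('ymin').text)))
    xmax = min(width - 1, int(float(bbox.find('xmax').text)))
    ymax = min(height - 1, int(float(bbox.find('ymax').text)))
    return xmin, ymin, xmax, ymax


root = ET.fromstring(SAMPLE_XML.strip())
size = root.find('size')
width = int(size.find('width').text)
height = int(size.find('height').text)
boxes = [clamped_box(obj, width, height) for obj in root.iter('object')]
print(boxes)
```

If the XML files have no `size` element, the image dimensions would have to be read from the image file instead.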
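After this conversion, the images can be used for training directly, with no need to resize them first, because the author has already written the corresponding img_preprocess function in the project.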
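To illustrate why no manual resize is needed, here is a sketch of how a letterbox-style preprocessing step typically maps a box from the original image to the network input. This is an assumption about the general technique, not a copy of the project's function, which may differ in detail. The name `letterbox_box` and the image and target sizes are hypothetical.

```python
def letterbox_box(box, image_size, target_size):
    """Map a box from original image coordinates to letterboxed coordinates.

    image_size and target_size are (width, height). The image is scaled by a
    single factor so that it fits inside the target, then centered with padding.
    """
    x_min, y_min, x_max, y_max = box
    iw, ih = image_size
    tw, th = target_size
    # One scale for both axes keeps the aspect ratio unchanged
    scale = min(tw / iw, th / ih)
    new_w, new_h = int(scale * iw), int(scale * ih)
    # Padding that centers the scaled image inside the target
    pad_x, pad_y = (tw - new_w) // 2, (th - new_h) // 2
    return (x_min * scale + pad_x, y_min * scale + pad_y,
            x_max * scale + pad_x, y_max * scale + pad_y)


# Hypothetical sizes: an 832x416 image fed to a 416x416 network input
result = letterbox_box((48, 240, 195, 371), image_size=(832, 416), target_size=(416, 416))
print(result)
```

Because the boxes are transformed together with the image, the annotations in the txt file can stay in original image coordinates.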