数据集标签转换小工具
VOC系列的数据集(如VOC2007、VOC2012等)的标签是以xml文件给出的,而本文yolo算法使用的标签是[class_id, x1, y1, x2, y2]形式的数据(坐标按图像宽、高归一化到0~1),因此用VOC数据集训练yolo算法之前,需要先把标签转换成这种格式。
代码
首先需要确保自己的环境中已安装numpy(没有的话执行 pip install numpy);xml是Python标准库自带的模块,无需另外安装。然后新建一个名为label_trans.py的文件,把下面的代码粘贴进去,再根据自己的数据集路径(DATA_ROOT)和需要转换的文件列表(name,例如trainval.txt)修改代码中对应的两处设置,最后运行 python label_trans.py,就可以在当前目录下的labels文件夹中生成对应的标签文件。
import os
from pathlib import Path
import os.path as osp
import numpy as np
import xml.etree.ElementTree as ET
# Mapping from the 20 PASCAL VOC class names to integer class ids.
class_to_ind = {
    'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3,
    'bottle': 4, 'bus': 5, 'car': 6, 'cat': 7, 'chair': 8,
    'cow': 9, 'diningtable': 10, 'dog': 11, 'horse': 12,
    'motorbike': 13, 'person': 14, 'pottedplant': 15,
    'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19,
}


def voc_to_yolo(path, keep_difficult=False):
    """Convert one VOC XML annotation into an array of normalised boxes.

    Args:
        path: Filename or file object of the annotation; it is handed
            straight to ``xml.etree.ElementTree.parse``.
        keep_difficult: When False (the default), objects whose
            ``<difficult>`` flag is 1 are skipped.

    Returns:
        A float32 array of shape ``(N, 5)``. Each row is
        ``[class_id, x1, y1, x2, y2]`` where the x values are divided by
        the image width and the y values by the image height. An
        annotation without any kept object yields shape ``(0, 5)``.

    Raises:
        KeyError: If an object's name is not a key of ``class_to_ind``.
    """
    root = ET.parse(path).getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    rows = []
    for obj in root.iter('object'):
        # A missing or empty <difficult> tag is treated as "not difficult"
        # instead of crashing on ``None.text``.
        flag = (obj.findtext('difficult') or '').strip()
        difficult = bool(flag) and int(flag) == 1
        if difficult and not keep_difficult:
            continue

        name = obj.find('name').text.lower().strip()
        label_idx = class_to_ind[name]

        bbox = obj.find('bndbox')
        # float() accepts both "48" and "48.0"; for integer strings the
        # quotient is identical to the one obtained via int().
        xmin, ymin, xmax, ymax = (
            float(bbox.find(tag).text)
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        rows.append([
            label_idx,
            xmin / width,
            ymin / height,
            xmax / width,
            ymax / height,
        ])  # [class_id, x1, y1, x2, y2]

    # reshape keeps the result two-dimensional even when no row was kept.
    return np.array(rows, dtype=np.float32).reshape(-1, 5)
# Usage: run `python label_trans.py`
DATA_ROOT = 'D:/programFiles/yolov5/datasets/VOCdevkit2007'  # root directory of the VOC dataset (edit to match your local path)

if __name__ == '__main__':
    path = osp.join(Path(DATA_ROOT), 'VOC2007')
    annopath = osp.join(path, 'Annotations', '%s.xml')
    name = 'trainval.txt'  # image-set list selecting which samples to convert (edit as needed)
    labels_path = 'labels/'
    os.makedirs(labels_path, exist_ok=True)

    # Write one label file per image id listed in the image-set file.
    # The `with` block makes sure the list file is closed afterwards.
    with open(osp.join(path, 'ImageSets', 'Main', name)) as id_file:
        for line in id_file:
            line = line.strip()
            if not line:
                # A blank line would otherwise expand to the path ".xml".
                continue
            boxes = voc_to_yolo(annopath % line)
            np.savetxt(osp.join(labels_path, line + '.txt'), boxes, fmt='%.6f')