基于深度学习模型的自动标注,解放双手。
工欲善其事,必先利其器。做过深度学习工程落地的同学都知道,工业数据很珍贵,而且往往需要自己打标签。时间就是生命,于是我结合网上的资料写了一个小脚本,用检测模型实现自动标注,后期只需人工校验即可。Ideas are cheap, show me the code!
这个demo需要有python和目标检测的知识储备,不需要太多,一点点即可。
1.首先导入需要的包。
from xml.etree import ElementTree as ET
from xml.dom import minidom
2.生成xml对象。
def create_xml_tree(sources, image_name, h, w, depth=3):
    """Build the Pascal VOC ``<annotation>`` skeleton for a single image.

    The returned tree holds the folder, filename, path, source, size and
    segmented nodes; ``<object>`` entries are appended to it separately.

    Args:
        sources: Path of the image file. It is stored verbatim in <path>,
            and the name of its parent directory goes into <folder>.
        image_name: Value written to <filename>.
        h: Image height, written to <size>/<height>.
        w: Image width, written to <size>/<width>.
        depth: Channel count written to <size>/<depth>; defaults to 3.

    Returns:
        The root ``annotation`` Element.
    """
    import os  # local import: this snippet only imports the XML modules

    annotation = ET.Element('annotation')

    folder = ET.SubElement(annotation, 'folder')
    # os.path understands the platform separator (e.g. the backslashes that
    # os.path.join produces on Windows) and yields '' for a bare file name,
    # where split('/')[-2] would raise IndexError or pick the wrong part.
    folder.text = os.path.basename(os.path.dirname(sources))

    filename = ET.SubElement(annotation, 'filename')
    filename.text = image_name
    path = ET.SubElement(annotation, 'path')
    path.text = sources

    source = ET.SubElement(annotation, 'source')
    database = ET.SubElement(source, 'database')
    database.text = 'Unknown'

    size = ET.SubElement(annotation, 'size')
    width = ET.SubElement(size, 'width')
    width.text = str(w)
    height = ET.SubElement(size, 'height')
    height.text = str(h)
    depth_node = ET.SubElement(size, 'depth')
    depth_node.text = str(depth)

    segmented = ET.SubElement(annotation, 'segmented')
    segmented.text = '0'
    return annotation
3.写入object信息。
def create_object(root, result_list):
    """Append one Pascal VOC ``<object>`` element per detection to *root*.

    Args:
        root: The ``<annotation>`` Element to extend; modified in place.
        result_list: Iterable of detection dicts shaped like::

            [{
                "xmin": x_min,
                "ymin": y_min,
                "xmax": x_max,
                "ymax": y_max,
                "name": cls,
            }]

    Returns:
        None.

    Raises:
        KeyError: If a detection dict lacks one of the keys above.
    """
    for data in result_list:
        # Named xml_object (not `object`) so the builtin is not shadowed.
        xml_object = ET.SubElement(root, 'object')
        ET.SubElement(xml_object, 'name').text = str(data["name"])
        ET.SubElement(xml_object, 'pose').text = 'Unspecified'
        ET.SubElement(xml_object, 'truncated').text = '0'
        ET.SubElement(xml_object, 'difficult').text = '0'
        bndbox = ET.SubElement(xml_object, 'bndbox')
        # Coordinate tags are written in VOC order: xmin, ymin, xmax, ymax.
        for key in ('xmin', 'ymin', 'xmax', 'ymax'):
            ET.SubElement(bndbox, key).text = str(data[key])
4.格式化xml对象
def prettify_xml(elem):
    """Return *elem* serialized as an indented, human-readable XML string.

    The element is first dumped to UTF-8 bytes with ElementTree, then
    re-parsed with minidom purely to obtain its pretty-printer.
    """
    dom = minidom.parseString(ET.tostring(elem, 'utf-8'))
    return dom.toprettyxml(indent=" ")
5.写入xml文件中
# Save the pretty-printed XML string; the encoding is given explicitly so the
# result does not depend on the platform's default text encoding.
with open(file_path, mode="w", encoding="utf-8") as xml_file:
    xml_file.write(pretty_xml)
需要注意的是,要先把目标检测模型的输出整理成上面 result_list 的格式,这样就可以直接调用这个脚本。完整代码如下(tool.py):
from xml.etree import ElementTree as ET
from xml.dom import minidom
def create_tree(sources, image_name, h, w, depth=3):
    """Build the Pascal VOC ``<annotation>`` skeleton for a single image.

    The returned tree holds the folder, filename, path, source, size and
    segmented nodes; ``<object>`` entries are appended by create_object.

    Args:
        sources: Path of the image file. It is stored verbatim in <path>,
            and the name of its parent directory goes into <folder>.
        image_name: Value written to <filename>.
        h: Image height, written to <size>/<height>.
        w: Image width, written to <size>/<width>.
        depth: Channel count written to <size>/<depth>; defaults to 3.

    Returns:
        The root ``annotation`` Element.
    """
    import os  # local import: tool.py only imports the XML modules at top

    annotation = ET.Element('annotation')

    folder = ET.SubElement(annotation, 'folder')
    # The caller builds image paths with os.path.join, which uses backslashes
    # on Windows; split('/')[-2] then raises IndexError or picks the wrong
    # component. os.path handles the platform separator and bare file names.
    folder.text = os.path.basename(os.path.dirname(sources))

    filename = ET.SubElement(annotation, 'filename')
    filename.text = image_name
    path = ET.SubElement(annotation, 'path')
    path.text = sources

    source = ET.SubElement(annotation, 'source')
    database = ET.SubElement(source, 'database')
    database.text = 'Unknown'

    size = ET.SubElement(annotation, 'size')
    width = ET.SubElement(size, 'width')
    width.text = str(w)
    height = ET.SubElement(size, 'height')
    height.text = str(h)
    depth_node = ET.SubElement(size, 'depth')
    depth_node.text = str(depth)

    segmented = ET.SubElement(annotation, 'segmented')
    segmented.text = '0'
    return annotation
def create_object(root, result_list):
    """Attach a Pascal VOC <object> node to *root* for every detection.

    Each entry of *result_list* is a dict providing "name", "xmin", "ymin",
    "xmax" and "ymax". *root* is mutated in place and nothing is returned.
    """
    fixed_fields = (
        ('pose', 'Unspecified'),
        ('truncated', '0'),
        ('difficult', '0'),
    )
    coord_keys = ('xmin', 'ymin', 'xmax', 'ymax')

    for det in result_list:
        obj_node = ET.SubElement(root, 'object')

        name_node = ET.SubElement(obj_node, 'name')
        name_node.text = str(det["name"])

        # Constant per-object metadata, in the fixed VOC order.
        for tag, value in fixed_fields:
            field_node = ET.SubElement(obj_node, tag)
            field_node.text = value

        box_node = ET.SubElement(obj_node, 'bndbox')
        for key in coord_keys:
            coord_node = ET.SubElement(box_node, key)
            coord_node.text = '%s' % det[key]
def prettify_xml(elem):
    """Serialize *elem* and return it as a pretty-printed XML string.

    ElementTree has no indentation option in older Python versions, so the
    UTF-8 bytes it produces are handed to minidom for formatting.
    """
    raw_bytes = ET.tostring(elem, 'utf-8')
    document = minidom.parseString(raw_bytes)
    pretty = document.toprettyxml(indent=" ")
    return pretty
调用脚本(把上面的 tool.py 放在 util 目录下,对应脚本中的 `from util import tool`):
from ultralytics import YOLO
from util import tool
import os
from xml.etree import ElementTree as ET
# Path to the downloaded detection weights (e.g. yolov10n.pt).
model = YOLO(model='yourmodel')
# Class-index -> class-name mapping provided by the model.
name = model.names
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
# Directory searched recursively for images to annotate; replace with your own.
dir_path = "yourdir"

for root, _dirs, files in os.walk(dir_path):
    for file in files:
        if not file.lower().endswith(image_extensions):
            continue
        image_path = os.path.join(root, file)
        results = model(image_path)
        if not results:
            # Without a result there is no image shape; skip rather than
            # reuse the shape left over from a previous image.
            continue
        # Indexed as (height, width), matching create_tree's h, w arguments.
        img_shape = results[0].orig_shape

        result_list = []
        for r in results:
            # Each xyxy row holds the box corners (x_min, y_min, x_max, y_max),
            # i.e. top-left and bottom-right, not top-left plus width/height.
            for cls_id, coord in zip(r.boxes.cls.tolist(), r.boxes.xyxy.tolist()):
                x_min, y_min, x_max, y_max = coord
                result_data = {
                    # Pascal VOC stores integer pixel coordinates.
                    "xmin": round(x_min),
                    "ymin": round(y_min),
                    "xmax": round(x_max),
                    "ymax": round(y_max),
                    "name": name[int(cls_id)],
                }
                print(result_data)
                result_list.append(result_data)

        annotation = tool.create_tree(image_path, file, img_shape[0], img_shape[1])
        tool.create_object(annotation, result_list)
        pretty_xml = tool.prettify_xml(annotation)

        # Swap the real extension for .xml next to the image. str.strip('.jpg')
        # would remove any of the characters '.', 'j', 'p', 'g' from BOTH ends
        # (e.g. './data/pig.jpg' -> '/data/pi'), corrupting the output path.
        file_path = os.path.splitext(image_path)[0] + '.xml'
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(pretty_xml)
        print(f"XML 文件已规范化并保存为 {file_path}")