深度学习——yolov5的txt和xml互转

在学习和工作中,有时需要自己新建数据集或向训练数据中添加新数据:先使用已有模型对新数据进行检测,得到 yolov5 格式的 txt 文件,再将其转成 xml,用标注工具对标注进行校正;之后再将 xml 转回 yolov5 训练使用的 txt 格式。
以下是使用多进程和多线程两种方式实现的txt和xml互转的代码。

20240320 更新:使用多进程调用多线程,处理速度更快。

20240320更新

使用多进程调用多线程来处理数据,大量数据处理时节约时间

线程数不宜设置过大,过大反而会因线程切换频繁导致速度变慢。
进程数可以根据自己的设备设置;数据量很大时,在不超过 CPU 核心数的前提下,进程数设置得越多处理越快。

xml->txt

# -*- coding: UTF-8 -*-
"""
@Project :yolov5_relu_fire_smoke_v1.3 
@IDE     :PyCharm 
@Author  :
@Date    :2024/2/22 15:58

将xml转为yolo训练使用的txt格式
xml保存时使用的是[x1,y1,x2,y2]坐标格式
yolo训练使用的是[xn,yn,wn,hn]坐标格式
data_root
    |----annotations
    |----images
    |----labels
"""
import os
import multiprocessing
from concurrent import futures
from typing import List, Tuple
from copy import deepcopy

import xml.etree.ElementTree as ET
import cv2
import numpy as np


def decodeVocAnnotation(voc_xml_path, class_index_dict):
    """
    Parse one Pascal VOC xml file into a list of boxes.

    Each element of the returned list is ``(class_index, xmin, ymin, xmax, ymax)``.
    ``class_index`` is looked up in ``class_index_dict`` by the object's
    ``<name>``; the coordinates are ints. An empty list means the xml contains
    no ``<object>``, i.e. the image is a background image.

    :param voc_xml_path: path to the .xml file
    :param class_index_dict: mapping class name -> class index
    :return: list of (class_index, xmin, ymin, xmax, ymax) tuples
    :raises AssertionError: if the path does not end with ".xml"
    :raises KeyError: if an object's name is missing from ``class_index_dict``
    """
    assert voc_xml_path.endswith(".xml"), "voc_xml_path must endswith .xml"

    # `with` closes the file even if parsing raises.
    with open(voc_xml_path, 'r', encoding='utf-8') as xml_file:
        root = ET.ElementTree().parse(xml_file)

    information = []
    for obj in root.iter('object'):
        name = obj.find('name').text
        # VOC stores corner-format boxes.
        box = obj.find('bndbox')
        # Some annotation tools write fractional coordinates ("12.5"), on which
        # int() raises; go through float() first (int() then truncates as before).
        xmin, ymin, xmax, ymax = (int(float(box.find(tag).text))
                                  for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        information.append((class_index_dict[name], xmin, ymin, xmax, ymax))

    return information


def xyxy2xywh(matrix):
    """
    Convert corner-format boxes to center-format boxes.

    :param matrix: numpy array of shape (n, 4); each row is x1, y1, x2, y2
    :return: numpy array of shape (n, 4); each row is center_x, center_y, width, height
    :raises ValueError: if ``matrix`` is not at least 2-D with exactly 4 columns
    """
    # Guard clause: need at least two dimensions and four coordinate columns.
    if not (matrix.ndim >= 2 and matrix.shape[1] == 4):
        raise ValueError("Input matrix must have shape (n, 4)")

    left, top, right, bottom = (matrix[:, col] for col in range(4))
    return np.column_stack((
        (left + right) / 2,       # center x
        (top + bottom) / 2,       # center y
        np.abs(right - left),     # width
        np.abs(bottom - top),     # height
    ))


def run_thread(root, file, image_root, save_txt_root, class_index_dict):
    """
    Convert the VOC xml of one image into a YOLO txt label file.

    The xml is looked up in the image's directory with "images" replaced by
    "annotations". The txt is written to the image's relative location under
    ``save_txt_root``. Images without an xml, or whose xml has no objects,
    are skipped and no txt is written. Errors are printed rather than raised,
    so one bad file does not stop the worker.

    :param root: directory containing the image (as yielded by os.walk)
    :param file: image file name, e.g. 000001.jpg
    :param image_root: root directory of all images
    :param save_txt_root: root directory under which txt labels are written
    :param class_index_dict: mapping class name -> class index
    :return: None
    """
    image_path = os.path.join(root, file)
    # Build output paths from the file stem. str.replace(suffix, ...) would
    # rewrite every occurrence of the extension in the path. For a file with
    # no extension (suffix "") it would insert the new extension between every
    # character.
    stem = os.path.splitext(file)[0]
    xml_path = os.path.join(root.replace("images", "annotations"), stem + ".xml")
    txt_path = os.path.join(root.replace(image_root, save_txt_root), stem + ".txt")

    try:
        if not os.path.exists(xml_path):
            return

        # rows: cls_index, x1, y1, x2, y2
        bbox = np.array(decodeVocAnnotation(xml_path, class_index_dict), dtype=np.float32)
        if len(bbox) == 0:
            # Background image: nothing to write.
            return

        # The image is only read to get its size for normalization.
        image = cv2.imread(image_path)
        if image is None:
            print(f"\n\033[31m{image_path} is None\033[0m")
            return
        print(f"\r\033[32m{image_path}\033[0m", end='')

        imh, imw = image.shape[:2]

        # xyxy -> xywh, then normalize by image width / height
        bbox[:, 1:] = xyxy2xywh(bbox[:, 1:])
        bbox[:, [1, 3]] /= imw
        bbox[:, [2, 4]] /= imh

        os.makedirs(os.path.dirname(txt_path), exist_ok=True)
        with open(txt_path, "w", encoding="utf-8") as wFile:
            wFile.writelines(f"{int(cls_id)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n"
                             for cls_id, x, y, w, h in bbox)

    except Exception as e:
        # Best effort: report and move on to the next image.
        print(e)
        print(f"\n\033[31m{image_path}\033[0m")


def run_process(root_file, image_root, save_txt_root, class_index_dict, max_workers=5):
    """
    Worker-process entry point: convert a batch of images using a thread pool.

    Creating a process is expensive, so each process handles many images and
    fans them out to threads.

    :param root_file: [(root, file), ...] images to convert
    :param image_root: root directory of all images
    :param save_txt_root: root directory under which txt labels are written
    :param class_index_dict: mapping class name -> class index
    :param max_workers: threads per process (default 5, the previous fixed
        value); set to 1 to make debugging easier
    :return: None
    """
    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for root, file in root_file:
            executor.submit(run_thread, root, file, image_root, save_txt_root, class_index_dict)


if __name__ == '__main__':
    # Class name -> class index; must cover every <name> used in the xml files.
    class_index_dict = {
        "fire": 0,
        "smoke": 1,
    }

    # data_root = r"Z:\Datasets\Detection\FireSmoke\TSMFireSmoke"
    # data_root = r"Z:\Datasets\Detection\FireSmoke\TSMCandle"
    # data_root = r"Z:\Datasets\FireSmoke_v4"
    data_root = r"E:\CodeFiles\pycharm\YOLO\yolov5\my_test\data"
    # data_root = r"Z:\Datasets\Detection\FireSmoke\candle-test"
    data_root = os.path.abspath(data_root)

    image_root = os.path.join(data_root, "images")

    # save_txt_root = r""
    save_txt_root = os.path.join(data_root, "labels")

    # Directories to skip (substring match against each walked directory path).
    exclude_dirs = [
        r"background",
    ]

    # NOTE: multi-process / multi-thread code is hard to debug; set the pool
    # sizes to 1 to debug.
    max_workers = 6  # number of worker processes

    print(f"max_workers:{max_workers}")

    # Number of images handed to one process task.
    max_file_num = 2000
    # Current batch of (root, file) pairs.
    root_file_list: List[Tuple] = list()
    # Handles of submitted batches, checked after submission finishes.
    async_results = []

    # Process pool; size it for your machine - too many processes slows things down.
    pool = multiprocessing.Pool(processes=max_workers)
    for root, _, files in os.walk(image_root):
        if any(exclude_dir in root for exclude_dir in exclude_dirs):
            continue

        for file in files:
            root_file_list.append((root, file))

            if len(root_file_list) >= max_file_num:
                # Submit the full batch and start a new list. Rebinding (rather
                # than clear()) is required because the pool may pickle the
                # arguments after apply_async returns.
                async_results.append(pool.apply_async(
                    run_process, (root_file_list, image_root, save_txt_root, class_index_dict)))
                root_file_list = []

    # Submit the last, partially filled batch (if any).
    if root_file_list:
        async_results.append(pool.apply_async(
            run_process, (root_file_list, image_root, save_txt_root, class_index_dict)))

    pool.close()
    # get() waits for each batch and re-raises any exception from the worker,
    # which would otherwise be lost silently.
    for async_result in async_results:
        try:
            async_result.get()
        except Exception as e:
            print(f"\n\033[31mworker failed: {e}\033[0m")
    pool.join()

    print("\nFinish ...")

txt->xml

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :TestCode 
@IDE     :PyCharm 
@Author  :
@Date    :2023/7/21 17:15

yolov5检测出来的目标结果,转成xml
xml保存时使用的是[x1,y1,x2,y2]坐标格式
yolo检测结果保存使用的是[xn,yn,wn,hn]坐标格式
如果保存txt保存了置信度则txt每一行是:[class_index, xn, yn, wn, hn, conf]
使用线程池实现

data_root
    |----annotations
    |----images
    |----labels
"""
import os
import multiprocessing
from concurrent import futures
from typing import List, Tuple
from copy import deepcopy

import cv2
import numpy as np
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom


def create_voc_xml(image_folder, image_filename, width: int, height: int, labels,
                   save_path, class_name_dict, conf_thresh_dict=None):
    """
    Write one Pascal VOC xml file describing an image and its boxes.

    :param image_folder: image folder relative to the image root (written to <folder>)
    :param image_filename: image file name, e.g. 000001.jpg; the xml is saved
        as ``<save_path>/<stem>.xml``
    :param width: image width in pixels
    :param height: image height in pixels
    :param labels: boxes, rows ``[class_index, xmin, ymin, xmax, ymax]``,
        optionally followed by a confidence column
    :param save_path: directory in which the xml is written (created if missing)
    :param class_name_dict: mapping class index -> class name
    :param conf_thresh_dict: mapping class index -> confidence threshold. When
        given, each row must have a 6th confidence column, and rows below
        their class threshold are dropped. When None, extra columns are ignored.
    :return: None
    """
    root = ET.Element("annotation")

    # Image information
    ET.SubElement(root, "folder").text = str(image_folder)
    ET.SubElement(root, "filename").text = image_filename
    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = "3"  # number of channels (always written as 3)

    for label in labels:
        label = np.asarray(label)
        # Class index and corners, truncated to int (same as astype(np.int32) before).
        # Slicing to 5 columns lets a txt with a conf column pass when no
        # thresholds are given, instead of failing to unpack.
        class_index, x1, y1, x2, y2 = (int(v) for v in label[:5].astype(np.int32))

        if conf_thresh_dict is not None:
            # Drop detections below their class-specific threshold.
            if label[5] < conf_thresh_dict[class_index]:
                continue

        # Boxes decoded from normalized center format can spill slightly
        # outside the image; clamp the corners to [0, width] x [0, height].
        x1, x2 = (min(max(v, 0), width) for v in (x1, x2))
        y1, y2 = (min(max(v, 0), height) for v in (y1, y2))

        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = class_name_dict[class_index]
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"
        bndbox = ET.SubElement(obj, "bndbox")
        for tag, value in (("xmin", x1), ("ymin", y1), ("xmax", x2), ("ymax", y2)):
            ET.SubElement(bndbox, tag).text = str(value)

    # Pretty-print with 4-space indentation for readability.
    pretty_xml = minidom.parseString(ET.tostring(root, encoding="utf-8")).toprettyxml(indent=" " * 4)

    save_path = os.path.join(save_path, f"{os.path.splitext(image_filename)[0]}.xml")
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # toprettyxml() writes a declaration without an encoding, so readers assume
    # UTF-8. Write UTF-8 explicitly instead of the platform default (e.g. GBK
    # on Chinese Windows), which garbles non-ASCII class names.
    with open(save_path, "w", encoding="utf-8") as xmlFile:
        xmlFile.write(pretty_xml)


def run_thread(root, image_file, save_root, image_root, txt_root, class_name_dict, conf_thresh_dict=None):
    """
    Convert the YOLO txt label of one image into a Pascal VOC xml file.

    The txt is looked up at the image's relative location under ``txt_root``.
    The xml is written to the same relative location under ``save_root``.
    Images without a txt, or with an empty txt, are skipped. Errors are
    printed rather than raised, so one bad file does not stop the worker.

    @param root: directory containing the image, e.g. .../images/train
    @param image_file: image file name, e.g. 0000000.jpg
    @param save_root: root directory for xml output, e.g. .../annotations
    @param image_root: root directory of all images, e.g. .../images
    @param txt_root: root directory of the txt labels, e.g. .../labels
    @param class_name_dict: mapping class index -> class name
    @param conf_thresh_dict: mapping class index -> confidence threshold; omit
        (None) when the txt files carry no confidence column, e.g. output of
        yolov5 detect.py without saved confidences
    @return: None
    """
    image_path = os.path.join(root, image_file)
    # Build the txt path from the file stem. str.replace(suffix, ...) would also
    # hit other occurrences of the extension, and for an extension-less file
    # (suffix "") it would corrupt the whole path.
    txt_file = os.path.join(root.replace(image_root, txt_root),
                            os.path.splitext(image_file)[0] + ".txt")

    # Catch everything so one bad file does not stop the worker.
    try:
        # Check for the label first so images without one are not decoded needlessly.
        if not os.path.exists(txt_file):
            return

        image = cv2.imread(image_path)
        if image is None:
            print(f"\n\033[31mError {image_path}\033[0m")
            return
        imh, imw = image.shape[:2]

        # rows: class_index xn yn wn hn [conf]; ndmin=2 keeps single-line files 2-D
        labels = np.loadtxt(txt_file, dtype=np.float32, ndmin=2)
        if labels.size == 0:
            # Empty txt: nothing to convert.
            return

        # normalized xywh -> pixel xywh
        labels[:, [1, 3]] *= imw
        labels[:, [2, 4]] *= imh
        # pixel xywh -> xyxy (copies, because the columns are overwritten below)
        center_x, center_y, box_w, box_h = (labels[:, k].copy() for k in range(1, 5))
        labels[:, 1] = center_x - box_w / 2  # xmin
        labels[:, 2] = center_y - box_h / 2  # ymin
        labels[:, 3] = center_x + box_w / 2  # xmax
        labels[:, 4] = center_y + box_h / 2  # ymax

        create_voc_xml(root.replace(image_root + os.sep, ""),  # Z:\FireData\images\train -> train
                       image_filename=image_file,
                       width=imw,
                       height=imh,
                       labels=labels,
                       save_path=root.replace(image_root, save_root),
                       class_name_dict=class_name_dict,
                       conf_thresh_dict=conf_thresh_dict)
        # Progress output
        print(f"\r{image_path}", end='')

    except Exception as e:
        print(f"{image_path} \n{e}\n\n")


def run_process(root_file, save_root, image_root, txt_root, class_name_dict, conf_thresh_dict=None,
                max_workers=5):
    """
    Worker-process entry point: convert a batch of txt labels to xml using a thread pool.

    @param root_file: [(directory, image file name), ...] images to convert
    @param save_root: root directory for xml output, e.g. .../annotations
    @param image_root: root directory of all images, e.g. .../images
    @param txt_root: root directory of the txt labels, e.g. .../labels
    @param class_name_dict: mapping class index -> class name
    @param conf_thresh_dict: mapping class index -> confidence threshold; None
        when the txt files carry no confidence column
    @param max_workers: threads per process (default 5, the previous fixed
        value); set to 1 to make debugging easier
    @return: None
    """
    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for root, image_file in root_file:
            # Tasks are submitted one image at a time.
            executor.submit(run_thread, root, image_file, save_root, image_root, txt_root,
                            class_name_dict, conf_thresh_dict)


if __name__ == '__main__':
    # Class index -> class name
    class_name_dict = {
        0: "fire",
        1: "smoke"
    }
    # Per-class confidence thresholds. Leave empty when the txt files carry no
    # confidence column; it is then turned into None below.
    CONF_THRESH_DICT = dict()
    # CONF_THRESH_DICT.update(
    #     {
    #         0: 0.2,
    #         1: 0.2
    #     }
    # )
    if CONF_THRESH_DICT:
        # Every class needs a threshold, otherwise the lookup while writing xml fails.
        if class_name_dict.keys() != CONF_THRESH_DICT.keys():
            raise ValueError("class_name_dict.keys() != CONF_THRESH_DICT.keys().")
    else:
        CONF_THRESH_DICT = None

    # Dataset root directory
    data_root = r"E:\CodeFiles\pycharm\YOLO\yolov5\my_test\data"
    data_root = os.path.abspath(data_root)

    # Optional sub-directory to process
    sub_dir = r""
    # sub_dir = r"\train\fire_smoke"

    # Strip leading separators: os.path.join discards all earlier components
    # when a later one is rooted (e.g. r"\train" on Windows). normpath removes
    # the trailing separator that an empty sub_dir would otherwise leave.
    sub_dir = sub_dir.strip("\\/")
    # The three trees must share the same directory layout.
    save_root = os.path.normpath(os.path.join(data_root, "annotations", sub_dir))  # xml output
    txt_root = os.path.normpath(os.path.join(data_root, "labels", sub_dir))  # txt input
    image_root = os.path.normpath(os.path.join(data_root, "images", sub_dir))  # images

    # Directories to skip (substring match against each walked directory path).
    exclude_dirs = [
        r"background",
    ]

    # NOTE: multi-process / multi-thread code is hard to debug; set the pool
    # sizes to 1 to debug.
    max_workers = 6  # number of worker processes
    print(f"max_workers:{max_workers}")

    # Number of images handed to one process task.
    max_file_num = 2000
    # Current batch of (root, file) pairs.
    root_file_list: List[Tuple] = list()
    # Handles of submitted batches, checked after submission finishes.
    async_results = []

    # Process pool; size it for your machine - too many processes slows things down.
    pool = multiprocessing.Pool(processes=max_workers)
    for root, _, files in os.walk(image_root):
        if any(exclude_dir in root for exclude_dir in exclude_dirs):
            continue

        for file in files:
            root_file_list.append((root, file))

            if len(root_file_list) >= max_file_num:
                # Submit the full batch and start a new list. Rebinding (rather
                # than clear()) is required because the pool may pickle the
                # arguments after apply_async returns.
                async_results.append(pool.apply_async(
                    run_process, (root_file_list, save_root, image_root, txt_root,
                                  class_name_dict, CONF_THRESH_DICT)))
                root_file_list = []

    # Submit the last, partially filled batch (if any).
    if root_file_list:
        async_results.append(pool.apply_async(
            run_process, (root_file_list, save_root, image_root, txt_root,
                          class_name_dict, CONF_THRESH_DICT)))

    pool.close()
    # get() waits for each batch and re-raises any exception from the worker,
    # which would otherwise be lost silently.
    for async_result in async_results:
        try:
            async_result.get()
        except Exception as e:
            print(f"\n\033[31mworker failed: {e}\033[0m")
    pool.join()

    print("\nFinish ...")

20240317 版本

分别使用多线程和多进程两种方式处理数据

xml->txt

# -*- coding: UTF-8 -*-
"""
@Project :yolov5_relu_fire_smoke_v1.3 
@IDE     :PyCharm 
@Author  :mufeng
@Date    :2024/2/22 15:58

将xml转为yolo训练使用的txt格式
xml保存时使用的是[x1,y1,x2,y2]坐标格式
yolo训练使用的是[xn,yn,wn,hn]坐标格式
data_root
    |----annotations
    |----images
    |----labels
"""
import os
import multiprocessing
from concurrent import futures
from typing import List, Tuple
from copy import deepcopy

import xml.etree.ElementTree as ET
import cv2
import numpy as np


def decodeVocAnnotation(voc_xml_path, class_index_dict):
    """
    Parse one Pascal VOC xml file into a list of boxes.

    Each element of the returned list is ``(class_index, xmin, ymin, xmax, ymax)``.
    ``class_index`` is looked up in ``class_index_dict`` by the object's
    ``<name>``; the coordinates are ints. An empty list means the xml contains
    no ``<object>``, i.e. the image is a background image.

    :param voc_xml_path: path to the .xml file
    :param class_index_dict: mapping class name -> class index
    :return: list of (class_index, xmin, ymin, xmax, ymax) tuples
    :raises AssertionError: if the path does not end with ".xml"
    :raises KeyError: if an object's name is missing from ``class_index_dict``
    """
    assert voc_xml_path.endswith(".xml"), "voc_xml_path must endswith .xml"

    # `with` closes the file even if parsing raises.
    with open(voc_xml_path, 'r', encoding='utf-8') as xml_file:
        root = ET.ElementTree().parse(xml_file)

    information = []
    for obj in root.iter('object'):
        name = obj.find('name').text
        # VOC stores corner-format boxes.
        box = obj.find('bndbox')
        # Some annotation tools write fractional coordinates ("12.5"), on which
        # int() raises; go through float() first (int() then truncates as before).
        xmin, ymin, xmax, ymax = (int(float(box.find(tag).text))
                                  for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        information.append((class_index_dict[name], xmin, ymin, xmax, ymax))

    return information


def xyxy2xywh(matrix):
    """
    Convert corner-format boxes to center-format boxes.

    :param matrix: numpy array of shape (n, 4); each row is x1, y1, x2, y2
    :return: numpy array of shape (n, 4); each row is center_x, center_y, width, height
    :raises ValueError: if ``matrix`` is not at least 2-D with exactly 4 columns
    """
    # Guard clause: need at least two dimensions and four coordinate columns.
    if not (matrix.ndim >= 2 and matrix.shape[1] == 4):
        raise ValueError("Input matrix must have shape (n, 4)")

    left, top, right, bottom = (matrix[:, col] for col in range(4))
    return np.column_stack((
        (left + right) / 2,       # center x
        (top + bottom) / 2,       # center y
        np.abs(right - left),     # width
        np.abs(bottom - top),     # height
    ))


def run_thread(root, file, class_index_dict):
    """
    Convert the VOC xml of one image into a YOLO txt label file.

    The xml is read from, and the txt written to, the image's directory with
    "images" replaced by "annotations" and by "labels" respectively. Images
    without an xml, or whose xml has no objects, are skipped. Errors are
    printed rather than raised, so one bad file does not stop the run.

    :param root: directory containing the image (as yielded by os.walk)
    :param file: image file name, e.g. 000001.jpg
    :param class_index_dict: mapping class name -> class index
    :return: None
    """
    image_path = os.path.join(root, file)
    # Build output paths from the file stem. str.replace(suffix, ...) would
    # rewrite every occurrence of the extension in the path. For a file with
    # no extension (suffix "") it would insert the new extension between every
    # character.
    stem = os.path.splitext(file)[0]
    xml_path = os.path.join(root.replace("images", "annotations"), stem + ".xml")
    txt_path = os.path.join(root.replace("images", "labels"), stem + ".txt")

    try:
        if not os.path.exists(xml_path):
            return

        # rows: cls_index, x1, y1, x2, y2
        bbox = np.array(decodeVocAnnotation(xml_path, class_index_dict), dtype=np.float32)
        if len(bbox) == 0:
            # Background image: nothing to write.
            return

        # The image is only read to get its size for normalization.
        image = cv2.imread(image_path)
        if image is None:
            print(f"\n\033[31m{image_path} is None\033[0m")
            return
        print(f"\r\033[32m{image_path}\033[0m", end='')

        imh, imw = image.shape[:2]

        # xyxy -> xywh, then normalize by image width / height
        bbox[:, 1:] = xyxy2xywh(bbox[:, 1:])
        bbox[:, [1, 3]] /= imw
        bbox[:, [2, 4]] /= imh

        os.makedirs(os.path.dirname(txt_path), exist_ok=True)
        with open(txt_path, "w", encoding="utf-8") as wFile:
            wFile.writelines(f"{int(cls_id)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n"
                             for cls_id, x, y, w, h in bbox)

    except Exception as e:
        # Best effort: report and move on to the next image.
        print(e)
        print(f"\n\033[31m{image_path}\033[0m")


def run_process(root_file, class_index_dict):
    """
    Worker-process entry point: convert a batch of images one after another.

    Creating a process is expensive, so each process handles many images.
    The per-image work is exactly ``run_thread``; the loop previously
    duplicated that logic inline.

    :param root_file: [(root, file), ...] images to convert
    :param class_index_dict: mapping class name -> class index
    :return: None
    """
    for root, file in root_file:
        # One bad file (e.g. an unknown class name in its xml) must not abort
        # the remaining images of the batch.
        try:
            run_thread(root, file, class_index_dict)
        except Exception as e:
            print(e)
            print(f"\n\033[31m{os.path.join(root, file)}\033[0m")


if __name__ == '__main__':
    # Class name -> class index; must cover every <name> used in the xml files.
    class_index_dict = {
        "fire": 0,
        "smoke": 1,
    }

    # data_root = r"Z:\Datasets\Detection\FireSmoke\TSMFireSmoke"
    # data_root = r"Z:\Datasets\Detection\FireSmoke\TSMCandle"
    # data_root = r"Z:\Datasets\FireSmoke_v4"
    data_root = r"E:\CodeFiles\pycharm\YOLO\yolov5\my_test\data"
    # data_root = r"Z:\Datasets\Detection\FireSmoke\candle-test"
    data_root = os.path.abspath(data_root)
    image_root = os.path.join(data_root, "images")

    # Directories to skip (substring match against each walked directory path).
    exclude_dirs = [
        r"background",
    ]

    # NOTE: multi-process / multi-thread code is hard to debug; set the pool
    # size to 1 to debug.
    max_workers = 6  # number of threads / processes
    # Execution mode
    # run_type = "thread"  # multi-thread
    run_type = "process"  # multi-process

    print(f"running use run_type={run_type}, max_workers:{max_workers}")

    if run_type == "thread":
        tasks = []
        with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            for root, _, files in os.walk(image_root):
                if any(exclude_dir in root for exclude_dir in exclude_dirs):
                    continue

                for file in files:
                    # Tasks are submitted one image at a time.
                    tasks.append(executor.submit(run_thread, root, file, class_index_dict))
        # The executor stores exceptions inside the futures; report them
        # instead of losing them silently.
        for task in tasks:
            error = task.exception()
            if error is not None:
                print(f"\n\033[31m{error}\033[0m")

        print("\nFinish ...")

    elif run_type == "process":
        # Number of images handed to one process task.
        max_file_num = 1000
        # Current batch of (root, file) pairs.
        root_file_list: List[Tuple] = list()
        # Handles of submitted batches, checked after submission finishes.
        async_results = []

        # Process pool; size it for your machine - too many processes slows things down.
        pool = multiprocessing.Pool(processes=max_workers)
        for root, _, files in os.walk(image_root):
            if any(exclude_dir in root for exclude_dir in exclude_dirs):
                continue

            for file in files:
                root_file_list.append((root, file))

                if len(root_file_list) >= max_file_num:
                    # Submit the full batch and start a new list. Rebinding
                    # (rather than clear()) is required because the pool may
                    # pickle the arguments after apply_async returns.
                    async_results.append(pool.apply_async(run_process, (root_file_list, class_index_dict)))
                    root_file_list = []

        # Submit the last, partially filled batch (if any).
        if root_file_list:
            async_results.append(pool.apply_async(run_process, (root_file_list, class_index_dict)))

        pool.close()
        # get() waits for each batch and re-raises any exception from the
        # worker, which would otherwise be lost silently.
        for async_result in async_results:
            try:
                async_result.get()
            except Exception as e:
                print(f"\n\033[31mworker failed: {e}\033[0m")
        pool.join()

        print("\nFinish ...")

    else:
        print("run_type should be thread or process.")

txt->xml

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :TestCode 
@IDE     :PyCharm 
@Author  :mufeng
@Date    :2023/7/21 17:15

yolov5检测出来的目标结果,转成xml
xml保存时使用的是[x1,y1,x2,y2]坐标格式
yolo检测结果保存使用的是[xn,yn,wn,hn]坐标格式
如果保存txt保存了置信度则txt每一行是:[class_index, xn, yn, wn, hn, conf]
使用线程池实现

data_root
    |----annotations
    |----images
    |----labels
"""
import os
import multiprocessing
from concurrent import futures
from typing import List, Tuple
from copy import deepcopy

import cv2
import numpy as np
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom


def create_voc_xml(image_folder, image_filename, width: int, height: int, labels,
                   save_path, class_name_dict, conf_thresh_dict=None):
    """
    Write one Pascal VOC xml file describing an image and its boxes.

    :param image_folder: image folder relative to the image root (written to <folder>)
    :param image_filename: image file name, e.g. 000001.jpg; the xml is saved
        as ``<save_path>/<stem>.xml``
    :param width: image width in pixels
    :param height: image height in pixels
    :param labels: boxes, rows ``[class_index, xmin, ymin, xmax, ymax]``,
        optionally followed by a confidence column
    :param save_path: directory in which the xml is written (created if missing)
    :param class_name_dict: mapping class index -> class name
    :param conf_thresh_dict: mapping class index -> confidence threshold. When
        given, each row must have a 6th confidence column, and rows below
        their class threshold are dropped. When None, extra columns are ignored.
    :return: None
    """
    root = ET.Element("annotation")

    # Image information
    ET.SubElement(root, "folder").text = str(image_folder)
    ET.SubElement(root, "filename").text = image_filename
    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = "3"  # number of channels (always written as 3)

    for label in labels:
        label = np.asarray(label)
        # Class index and corners, truncated to int (same as astype(np.int32) before).
        # Slicing to 5 columns lets a txt with a conf column pass when no
        # thresholds are given, instead of failing to unpack.
        class_index, x1, y1, x2, y2 = (int(v) for v in label[:5].astype(np.int32))

        if conf_thresh_dict is not None:
            # Drop detections below their class-specific threshold.
            if label[5] < conf_thresh_dict[class_index]:
                continue

        # Boxes decoded from normalized center format can spill slightly
        # outside the image; clamp the corners to [0, width] x [0, height].
        x1, x2 = (min(max(v, 0), width) for v in (x1, x2))
        y1, y2 = (min(max(v, 0), height) for v in (y1, y2))

        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = class_name_dict[class_index]
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"
        bndbox = ET.SubElement(obj, "bndbox")
        for tag, value in (("xmin", x1), ("ymin", y1), ("xmax", x2), ("ymax", y2)):
            ET.SubElement(bndbox, tag).text = str(value)

    # Pretty-print with 4-space indentation for readability.
    pretty_xml = minidom.parseString(ET.tostring(root, encoding="utf-8")).toprettyxml(indent=" " * 4)

    save_path = os.path.join(save_path, f"{os.path.splitext(image_filename)[0]}.xml")
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # toprettyxml() writes a declaration without an encoding, so readers assume
    # UTF-8. Write UTF-8 explicitly instead of the platform default (e.g. GBK
    # on Chinese Windows), which garbles non-ASCII class names.
    with open(save_path, "w", encoding="utf-8") as xmlFile:
        xmlFile.write(pretty_xml)


def run_thread(root, image_file, save_root, image_root, txt_root, class_name_dict, conf_thresh_dict=None):
    """
    Convert the YOLO txt label of one image into a Pascal VOC xml file.

    The txt is looked up at the image's relative location under ``txt_root``.
    The xml is written to the same relative location under ``save_root``.
    Images without a txt, or with an empty txt, are skipped. Errors are
    printed rather than raised, so one bad file does not stop the run.

    @param root: directory containing the image, e.g. .../images/train
    @param image_file: image file name, e.g. 0000000.jpg
    @param save_root: root directory for xml output, e.g. .../annotations
    @param image_root: root directory of all images, e.g. .../images
    @param txt_root: root directory of the txt labels, e.g. .../labels
    @param class_name_dict: mapping class index -> class name
    @param conf_thresh_dict: mapping class index -> confidence threshold; omit
        (None) when the txt files carry no confidence column, e.g. output of
        yolov5 detect.py without saved confidences
    @return: None
    """
    image_path = os.path.join(root, image_file)
    # Build the txt path from the file stem. str.replace(suffix, ...) would also
    # hit other occurrences of the extension, and for an extension-less file
    # (suffix "") it would corrupt the whole path.
    txt_file = os.path.join(root.replace(image_root, txt_root),
                            os.path.splitext(image_file)[0] + ".txt")

    # Catch everything so one bad file does not stop the run.
    try:
        # Check for the label first so images without one are not decoded needlessly.
        if not os.path.exists(txt_file):
            return

        image = cv2.imread(image_path)
        if image is None:
            print(f"\n\033[31mError {image_path}\033[0m")
            return
        imh, imw = image.shape[:2]

        # rows: class_index xn yn wn hn [conf]; ndmin=2 keeps single-line files 2-D
        labels = np.loadtxt(txt_file, dtype=np.float32, ndmin=2)
        if labels.size == 0:
            # Empty txt: nothing to convert.
            return

        # normalized xywh -> pixel xywh
        labels[:, [1, 3]] *= imw
        labels[:, [2, 4]] *= imh
        # pixel xywh -> xyxy (copies, because the columns are overwritten below)
        center_x, center_y, box_w, box_h = (labels[:, k].copy() for k in range(1, 5))
        labels[:, 1] = center_x - box_w / 2  # xmin
        labels[:, 2] = center_y - box_h / 2  # ymin
        labels[:, 3] = center_x + box_w / 2  # xmax
        labels[:, 4] = center_y + box_h / 2  # ymax

        create_voc_xml(root.replace(image_root + os.sep, ""),  # Z:\FireData\images\train -> train
                       image_filename=image_file,
                       width=imw,
                       height=imh,
                       labels=labels,
                       save_path=root.replace(image_root, save_root),
                       class_name_dict=class_name_dict,
                       conf_thresh_dict=conf_thresh_dict)
        # Progress output
        print(f"\r{image_path}", end='')

    except Exception as e:
        print(f"{image_path} \n{e}\n\n")


def run_process(root_file, save_root, image_root, txt_root, class_name_dict, conf_thresh_dict=None):
    r"""
    Convert a batch of images' YOLO txt labels into VOC xml files (one task of the process pool).

    For every (directory, image file) pair, the matching txt file is looked up under
    ``txt_root`` using the same relative layout as ``image_root``. Each row
    ``class_index xn yn wn hn [conf]`` (normalized center format) is converted to
    ``class_index x1 y1 x2 y2 [conf]`` in pixels and passed to ``create_voc_xml``.
    Images without a txt file, or with an empty one, are skipped. Any error on one
    file is printed and the batch continues with the next file.

    @param root_file: [(..\images\train, 0000000.jpg), ...]
    @param save_root: ..\annotations\train
    @param image_root: ..\images\train
    @param txt_root: ...\txt\train
    @param class_name_dict: class index -> class name, forwarded to create_voc_xml
    @param conf_thresh_dict: per-class confidence thresholds; None when the txt files do not
                             contain a conf column (e.g. yolov5 detect.py run without saving conf)
    @return: None
    """
    for root, image_file in root_file:
        image_name, _ = os.path.splitext(image_file)
        image_path = os.path.join(root, image_file)
        # Catch per-file errors so one bad file does not stop the whole batch.
        try:
            # Build the txt path from the mirrored directory plus the file stem.
            # Replacing the suffix inside the full path string would also rewrite
            # directory names containing it, and an empty suffix would insert
            # ".txt" between every character of the path.
            txt_file = os.path.join(root.replace(image_root, txt_root, 1), image_name + ".txt")
            if not os.path.exists(txt_file):
                continue

            # One row per box: class_index xn yn wn hn [conf]
            labels = np.loadtxt(txt_file, dtype=np.float32)
            # Empty txt -> background image, nothing to write
            if labels.size == 0:
                continue
            # A single-row file loads as a 1-D array; make it (1, ncols).
            labels = np.atleast_2d(labels)

            # Decode the image only once we know it has labels: imread is the
            # expensive step and is only needed for the image width/height.
            image = cv2.imread(image_path)
            if image is None:
                print(f"\n\033[31mError {image_path}\033[0m")
                continue
            imh, imw = image.shape[:2]

            # normalized xywh -> pixel xywh
            labels[:, [1, 3]] *= imw
            labels[:, [2, 4]] *= imh
            # pixel center xywh -> corner xyxy: (x - w/2, y - h/2, x + w/2, y + h/2)
            center_xy = labels[:, 1:3]
            half_wh = labels[:, 3:5] / 2
            labels[:, 1:5] = np.concatenate((center_xy - half_wh, center_xy + half_wh), axis=1)

            create_voc_xml(root.replace(image_root + os.sep, ""),  # Z:\FireData\images\train -> train
                           image_filename=image_file,
                           width=imw,
                           height=imh,
                           labels=labels,
                           save_path=root.replace(image_root, save_root),
                           class_name_dict=class_name_dict,
                           conf_thresh_dict=conf_thresh_dict)
            # Progress output, overwritten in place
            print(f"\r{image_path}", end='')

        except Exception as e:
            print(f"{image_path} \n{e}\n\n")


if __name__ == '__main__':
    # Class index -> class name
    class_name_dict = {
        0: "fire",
        1: "smoke"
    }
    # Per-class confidence thresholds (different classes may use different thresholds).
    CONF_THRESH_DICT = None  # None means the txt files do not contain a conf column
    # CONF_THRESH_DICT = {
    #     0: 0.2,
    #     1: 0.2
    # }
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if CONF_THRESH_DICT is not None and class_name_dict.keys() != CONF_THRESH_DICT.keys():
        raise ValueError("class_name_dict.keys() != CONF_THRESH_DICT.keys().")

    # Dataset root directory
    data_root = r"E:\CodeFiles\pycharm\YOLO\yolov5\my_test\data"
    data_root = os.path.abspath(data_root)

    # Optional sub directory to process
    sub_dir = r""
    # sub_dir = r"\train\fire_smoke"
    # os.path.join drops every earlier component when it meets a component that starts
    # with a separator (on Windows it keeps only the drive). Without stripping,
    # r"\train\fire_smoke" would make all three roots below the same path.
    sub_dir = sub_dir.strip("\\/")
    sub_parts = [sub_dir] if sub_dir else []

    # The three trees must share the same directory layout.
    save_root = os.path.join(data_root, "annotations", *sub_parts)  # xml output root
    txt_root = os.path.join(data_root, "labels", *sub_parts)  # txt root, mirrors images layout
    image_root = os.path.join(data_root, "images", *sub_parts)  # image root

    # Directories to skip (substring match on the path below image_root)
    exclude_dirs = [
        r"background",
    ]

    def is_excluded(directory):
        # Match only the part below image_root, so a data_root whose own path happens to
        # contain an excluded name does not make the whole dataset get skipped.
        rel_dir = os.path.relpath(directory, image_root)
        return any(exclude_dir in rel_dir for exclude_dir in exclude_dirs)

    # NOTE: thread/process pools are hard to debug; set max_workers to 1 to debug.
    max_workers = 6  # number of threads / processes
    run_type = "thread"  # multi-threading
    # run_type = "process"  # multi-processing

    print(f"running use run_type={run_type}, max_workers:{max_workers}")

    if run_type == "thread":
        # One task per image
        with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            for root, _, files in os.walk(image_root):
                if is_excluded(root):
                    continue

                for file in files:
                    executor.submit(run_thread, root, file, save_root, image_root, txt_root,
                                    class_name_dict, CONF_THRESH_DICT)

        print("\nFinish ...")

    elif run_type == "process":
        # Number of images handled by one run_process task
        max_file_num = 1000
        batch: List[Tuple[str, str]] = []
        async_results = []

        # Process pool; adjust to your machine, too many processes slows things down
        pool = multiprocessing.Pool(processes=max_workers)

        def submit(file_batch):
            async_results.append(pool.apply_async(
                run_process,
                (file_batch, save_root, image_root, txt_root, class_name_dict, CONF_THRESH_DICT)))

        for root, _, files in os.walk(image_root):
            if is_excluded(root):
                continue

            for file in files:
                batch.append((root, file))

                if len(batch) >= max_file_num:
                    # The pool may serialize the arguments after apply_async returns, so hand
                    # this list over and start a fresh one instead of clearing it in place.
                    submit(batch)
                    batch = []

        # Remaining files that did not fill a whole batch
        if batch:
            submit(batch)

        pool.close()
        # Wait for all worker processes to finish
        pool.join()
        # apply_async keeps worker exceptions inside the AsyncResult; surface them.
        for result in async_results:
            try:
                result.get()
            except Exception as e:
                print(f"\n\033[31mprocess error: {e}\033[0m")

        print("\nFinish ...")

    else:
        print("run_type should be thread or process.")


YOLOv5(You Only Look Once version 5)是一种流行的目标检测算法,它需要的标注数据集包括图片文件以及与之对应的标签信息。数据集中的图片通常为.jpg或.png格式,标签信息则以.txt格式保存:每张图片对应一个txt文件,文件中的每一行是一个边界框标注。

将YOLOv5的txt格式数据转换为XML(如Pascal VOC标准),可以分为以下几个步骤:

1. **理解格式**:YOLOv5 txt格式的每一行是这样的:

```
class_id x_center y_center width height
```

其中坐标和宽高都是相对图片宽高归一化到0~1之间的值(用模型检测并保存置信度时,行末还会多一列conf)。而VOC XML格式需要包含图像信息、对象信息、类别等详细内容,边界框使用像素坐标的xmin、ymin、xmax、ymax表示。

2. **创建XML文件**:
   - 为每张图片创建一个新的XML文件,例如`image_001.xml`。
   - 以根节点(`<annotation>`)开始,写入图像信息(`<folder>`, `<filename>`)、来源(`<source>`)、尺寸(`<size>`)和时间戳(可选)。
   - 对于txt中的每一行数据,解析出类别和坐标,结合图片宽高换算成像素坐标后,写入一个`<object>`元素(包含`<name>`和`<bndbox>`)。

3. **编写脚本**:
   - 使用Python或其他编程语言,写一个循环遍历txt文件,读取每行数据并生成对应的XML结构。每行数据以空格分隔,直接`split()`即可。
   - 每次处理一行数据,创建一个新的`<object>`元素,并添加到`<annotation>`根节点下。

4. **保存文件**:
   - 将XML数据写入文件,记得关闭文件句柄(使用`with`语句可以自动关闭)。

下面是一个简单的Python示例(假设已导入所需的库)。注意:该示例只演示XML的构建方式。它假设每行格式为`image_id x y w h class_name`且坐标已经是像素值,与上面描述的真实YOLOv5格式不同,不能直接用于转换YOLOv5标签;实际转换请参考前文的完整代码:

```python
import os
import xml.etree.ElementTree as ET

def convert_yolov5_to_xml(txt_file, output_dir):
    # 创建根元素
    annotation = ET.Element('annotation')
    # 遍历txt文件
    with open(txt_file, 'r') as f:
        for line in f:
            # 解析每一行数据
            img_id, *coords_and_class = line.strip().split()
            object_element = ET.SubElement(annotation, 'object')
            ET.SubElement(object_element, 'name').text = img_id.split('.')[0]  # 图片ID
            bbox = ET.SubElement(object_element, 'bndbox')
            x, y, w, h = [int(float(coord)) for coord in coords_and_class[::2]]  # 提取坐标
            ET.SubElement(bbox, 'xmin').text = str(x)
            ET.SubElement(bbox, 'ymin').text = str(y)
            ET.SubElement(bbox, 'xmax').text = str(x + w)
            ET.SubElement(bbox, 'ymax').text = str(y + h)
            ET.SubElement(object_element, 'difficult').text = '0'  # 设置难度为0
            ET.SubElement(object_element, 'category').text = coords_and_class[-1]  # 类别
    # 构建文件路径
    output_path = os.path.join(output_dir, f"{img_id}.xml")
    tree = ET.ElementTree(annotation)
    tree.write(output_path)

# 使用函数
convert_yolov5_to_xml('labels.txt', 'output_voc')
```

记得替换`labels.txt`为你实际的txt文件路径,以及指定输出目录。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值