使用 YOLO 做目标检测时,如果检测结果出现多框问题,可以对数据集标签中的目标框尺寸进行聚类,生成与该数据集对应的 anchors。

聚类需要使用 xml 格式的标签文件。生成 anchors 以后,在使用 YOLO 训练时,将所用模型配置文件(例如 models 目录下的 yolov5s.yaml)中的 anchors 替换为生成的值。注意:脚本输出的 anchors 保留两位小数,替换前需要四舍五入。

将anchors替换到models的yaml中。

import os

try:
    import xml.etree.cElementTree as et
except ImportError:  # cElementTree was removed in Python 3.9
    import xml.etree.ElementTree as et

import numpy as np

from kmeans import kmeans, avg_iou

FILE_ROOT = "/mnt/"  # dataset root directory
ANNOTATION_ROOT = "Annotations"  # name of the XML label folder under FILE_ROOT
# Full path of the XML label folder. os.path.join is used instead of plain
# string concatenation so FILE_ROOT works with or without a trailing slash.
ANNOTATION_PATH = os.path.join(FILE_ROOT, ANNOTATION_ROOT)

# Where the best anchors are saved. The script opens this path with mode "w",
# which creates the file if it is missing (the parent directory must exist).
ANCHORS_TXT_PATH = "/mnt/fall_personanchor.txt"

CLUSTERS = 9  # number of anchor clusters; the original author advises leaving this at 9
CLASS_NAMES = ['down', 'person']  # class names to cluster; change to match your own dataset


def _read_float(node, path):
    """Return the text found at *path* under *node* as a float, or None if missing/invalid."""
    text = node.findtext(path)
    if text is None:
        return None
    try:
        return float(text)
    except ValueError:
        return None


def load_data(anno_dir, class_names):
    """Collect normalized (width, height) pairs of labelled boxes from XML files.

    Every regular file in *anno_dir* is parsed as XML. For each ``object``
    whose ``name`` is in *class_names*, the ``bndbox`` extent is divided by
    the image ``size/width`` and ``size/height`` and stored as
    ``[box_width, box_height]``.

    Files that cannot be parsed, files without a usable image size, objects
    with missing or non-numeric coordinates, and boxes with non-positive
    width or height are skipped instead of aborting the whole run.

    Args:
        anno_dir: Directory containing the XML annotation files.
        class_names: Container of class names to keep.

    Returns:
        A float array of shape ``(n, 2)``; ``(0, 2)`` when nothing was found.
    """
    boxes = []
    # Sorted so the order of the returned boxes does not depend on the
    # arbitrary ordering of os.listdir.
    for xml_name in sorted(os.listdir(anno_dir)):
        xml_pth = os.path.join(anno_dir, xml_name)
        if not os.path.isfile(xml_pth):
            # Sub-directories would make et.parse raise an OSError.
            continue
        print(f'{xml_name}')
        try:
            tree = et.parse(xml_pth)
        except et.ParseError as e:
            print(f"Error parsing {xml_name}: {e}")
            continue

        width = _read_float(tree, "./size/width")
        height = _read_float(tree, "./size/height")
        if width is None or height is None or width <= 0 or height <= 0:
            # Without a valid image size the boxes cannot be normalized.
            continue

        for obj in tree.findall("./object"):
            if obj.findtext("name") not in class_names:
                continue

            coords = [
                _read_float(obj, f"bndbox/{tag}")
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ]
            if any(value is None for value in coords):
                continue
            xmin, ymin, xmax, ymax = coords

            box_w = (xmax - xmin) / width
            box_h = (ymax - ymin) / height
            # Degenerate boxes carry no shape information.
            # NOTE(review): IoU-based k-means helpers often reject zero-area
            # boxes; confirm against the project's kmeans module.
            if box_w <= 0 or box_h <= 0:
                continue
            boxes.append([box_w, box_h])

    # reshape keeps the result two-dimensional even when no box was found.
    return np.array(boxes, dtype=float).reshape(-1, 2)


def main():
    """Cluster the labelled box sizes several times and save the best anchors.

    Loads normalized box sizes with ``load_data``, runs ``kmeans`` ten times,
    scores each run with ``avg_iou``, and writes the best accuracy, anchors
    and aspect ratios to ``ANCHORS_TXT_PATH``. When no boxes are found a
    message is printed and the output file is left untouched.
    """
    train_boxes = load_data(ANNOTATION_PATH, CLASS_NAMES)

    if len(train_boxes) == 0:
        print("No valid data found. Exiting...")
        return

    # Normalized box sizes are multiplied by this value to obtain the anchors.
    # NOTE(review): presumably the training input resolution; confirm.
    input_size = 640

    best_accuracy = 0
    best_anchors = []
    best_ratios = []

    for run in range(1, 11):  # number of clustering runs; adjust as needed
        clusters = kmeans(train_boxes, k=CLUSTERS)
        # Order the clusters by width.
        clusters = clusters[clusters[:, 0].argsort()]

        anchors_tmp = []
        for j in range(CLUSTERS):
            # float() converts NumPy scalars to plain Python floats so the
            # printed/saved list reads "[12.3, 45.6]" rather than
            # "[np.float64(12.3), ...]" under NumPy 2.
            anchor = [
                round(float(clusters[j][0]) * input_size, 2),
                round(float(clusters[j][1]) * input_size, 2),
            ]
            anchors_tmp.append(anchor)
            print(f"Anchors:{anchor}")

        temp_accuracy = avg_iou(train_boxes, clusters) * 100
        print("Train_Accuracy:{:.2f}%".format(temp_accuracy))

        ratios = np.around(clusters[:, 0] / clusters[:, 1], decimals=2).tolist()
        ratios.sort()
        print("Ratios:{}".format(ratios))
        print(20 * "*" + " {} ".format(run) + 20 * "*")

        if temp_accuracy > best_accuracy:
            best_accuracy = temp_accuracy
            best_anchors = anchors_tmp
            best_ratios = ratios

    # The file is opened only once there is something to write, and inside a
    # context manager so it is closed even if a write fails. newline=""
    # disables newline translation, so the explicit "\r\n" is written as-is
    # on every platform instead of becoming "\r\r\n" on Windows.
    with open(ANCHORS_TXT_PATH, "w", newline="") as anchors_txt:
        anchors_txt.write("Best Accuracy = " + str(round(best_accuracy, 2)) + '%' + "\r\n")
        anchors_txt.write("Best Anchors = " + str(best_anchors) + "\r\n")
        anchors_txt.write("Best Ratios = " + str(best_ratios))


if __name__ == '__main__':
    main()

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

lucky169

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值