目标检测中具体求解AP(average precision)的代码详解
之前笔记记录过分类模型常用模型评价指标详解的文章https://blog.csdn.net/qq_40728805/article/details/103829881
本篇文章是基于yolov3中的具体代码求解AP的过程,推荐看看目标检测评价指标https://github.com/rafaelpadilla/Object-Detection-Metrics
import tqdm
import torch
import numpy as np
def ap_per_class(tp, conf, pred_cls, target_cls):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (list). 若元素为0表示索引对应的样本框是负样本框,为1表示索引对应的样本框是正样本框
conf: Objectness value from 0-1 (list). 网络预测输出的置信度,理解为概率
pred_cls: Predicted object classes (list). 预测的类别
target_cls: True object classes (list). gt 类别
# Returns
The average precision as computed in py-faster-rcnn.
"""
# Sort by objectness, 降序排列返回数据对应的索引
i = np.argsort(-conf)
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
# Find unique classes 对类别去重,因为计算ap是对每类进行
unique_classes = np.unique(target_cls)
# Create Precision-Recall curve and compute AP for each class
ap, p, r = [], [], []
for c in tqdm.tqdm(unique_classes, desc="Computing AP"):
i = pred_cls &