功能性模块:(7)检测性能评估模块
一、模块介绍
其实每个算法的好坏都是有对应的评估标准的,如果你和老板说检测算法好或者不好,哈哈哈,那必然就是悲剧了。好或者不好是一个定性的说法,对于实际算法来说,到底怎么样算法算好?怎么样算法算不好?这些应该是有个定量的标准。对于检测来说,可能最常用的几个评价指标就是precision(查准率,就是你检测出来的目标有多少是真的目标),recall(查全率,就是实际的目标你的算法能检测出来多少),还有ap,map等。本篇博客其实就是让小伙伴们对自己的检测模型心里有一个底,换句话说这个模型你训练出来到底咋样?
二、代码实现
import numpy as np
import os
def voc_ap(rec, prec, use_07_metric=False):
"""Compute VOC AP given precision and recall. If use_07_metric is true, uses
the VOC 07 11-point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
def ComputeMAP(gt_root, predict_root, OVTHRESH=0.5):
"""
:param gt_root: 生成gt文件的根目录
:param predict_root: 算法跑出的根目录
:param overthresh: 设置的阈值
:return:
"""
# 获取所有的文件
files_gt = os.listdir(gt_root)
files_pred = os.listdir(predict_root)
files_gt.sort()
# 这个变量的目的是什么?保存gt中真正的框的数量
npos = 0
class_recs = {}
# 遍历所有gt文件
for file_gt in files_gt:
img_name = os.path.splitext(os.path.basename(file_gt))[0]
file_gt = os.path.join(gt_root, os.path.basename(file_gt))
print("*" * 80)
print("img name is: ", img_name)
print("gt file is: ", file_gt)
# 处理gt文件
with open(file_gt, 'r') as f:
lines = f.readlines()
splitlines = [x.strip().split(' ') for x in lines]
bbox = np.array([[float(z) for z in x[:]] for x in splitlines])
print("bbox is: \n", bbox)
det = [False] * len(bbox)
npos = npos + len(bbox)
class_recs[img_name] = {'bbox': bbox, 'det': det}
print("*" * 80)
print("Total npos is: ", npos)
# 遍历所有的检测结果
img_ids = []
confidence = []
BB = []
for file_pred in files_pred:
img_name = os.path.splitext(os.path.basename(file_pred))[0]
file_pred = os.path.join(pred_root, os.path.basename(file_pred))
print("*" * 80)
print("img_name is: ", img_name)
print("pred file is: ", file_pred)
with open(file_pred, 'r') as f:
lines = f.readlines()
splitlines = [x.strip().split(" ") for x in lines]
confidence_p = [float(x[0]) for x in splitlines]
bbox_p = [[float(z) for z in x[1:]] for x in splitlines]
# 根据confidence_p的长度,复制对应的img_name的str,生成对应长度的list
# ['20160220082030T28_H', '20160220082030T28_H', '20160220082030T28_H', '20160220082030T28_H']
img_ids.extend([img_name] * len(confidence_p))
confidence.extend(confidence_p)
BB.extend(bbox_p)
print(img_ids)
print(confidence)
print(BB)
confidence = np.array(confidence)
BB = np.array(BB)
print("*" * 80)
print("All files loaded!")
# 按照confidence的降序进行排列
sorted_idx = np.argsort(-confidence)
print("sorted idx is: ", sorted_idx)
BB = BB[sorted_idx, :]
img_ids = [img_ids[x] for x in sorted_idx]
# 计算对应的TPs 和 FPs
nd = len(img_ids)
tp = np.zeros(nd)
fp = np.zeros(nd)
wrong_count = 0
for d in range(nd):
print("We are now test: ", img_ids[d])
# 取出对应图像的gt
R = class_recs[img_ids[d]]
# 检测的结果
bb = BB[d, :].astype(float)
# 假设重叠面积初始为-inf
ovmax = -np.inf
BBGT = R['bbox'].astype(float)
print("bb: \n ", bb)
print("BBGT: \n", BBGT)
print("BBGT size is: ", BBGT.size)
if BBGT.size > 0:
# 计算覆盖的部分
ixmin = np.maximum(BBGT[:, 0], bb[0])
iymin = np.maximum(BBGT[:, 1], bb[1])
ixmax = np.minimum(BBGT[:, 2], bb[2])
iymax = np.minimum(BBGT[:, 3], bb[3])
iw = np.maximum(ixmax - ixmin + 1., 0.)
ih = np.maximum(iymax - iymin + 1., 0.)
# 计算交叉的面积
inters = iw * ih
# 计算iou吧
uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.)
+ (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
- inters)
overlaps = inters / uni
ovmax = np.max(overlaps)
jmax = np.argmax(overlaps)
print("overlaps is: ", overlaps)
print("ovmax is: ", ovmax)
print("jmax is: ", jmax)
if ovmax > OVTHRESH:
# 如果检测的这个标记还没有激活,默认是False
if not R['det'][jmax]:
tp[d] = 1.
R['det'][jmax] = 1
else:
fp[d] = 1.
wrong_count += 1
else:
fp[d] = 1.
wrong_count += 1
np.set_printoptions(threshold=np.inf)
# 计算 precision 和 recall
fp = np.cumsum(fp)
tp = np.cumsum(tp)
print("fp is: ", fp)
print("tp is: ", tp)
# 召回率(查全率)
rec = tp / float(npos)
# 精确率(查准率)
prec = tp / np.maximum(tp + fp, np.finfo(np.float).eps)
ap = voc_ap(rec, prec, False)
print("ap is: ", ap)
print("*" * 80)
print("RESULTS: \n")
print("Total %d images, %d objects" % (len(files_gt), npos))
print("Detected Correct: %d, Wrong: %d, Miss: %d under IOU: %f"
% (nd - wrong_count, wrong_count, npos - (nd - wrong_count), OVTHRESH))
print("Accuracy %f, Recall %f, Average Precision %f"
% (float(nd - wrong_count) / (nd), float(nd - wrong_count) / (npos), ap))
# 记录漏检的文件
f = open('./lost.txt', 'w')
for k, v in class_recs.items():
if False in v['det']:
f.write(str(k) + '.jpg' + '\n')
f.close()
if __name__ == "__main__":
gt_root = './mini_test/gt/'
pred_root = './mini_test/res/'
ComputeMAP(gt_root, pred_root)
LZ就不详细讲代码了,注释已经很详细了,主要是你的gt应该是什么样子的呢?
- 命名标准:img_name.txt
- gt格式:
# x1 y1 x2 y2
965 209 1040 329
- res格式:
# score x1 y1 x2 y2
0.9999481 962 222 1043 331
0.9999091 635 251 747 412
0.9783503 1795 340 1836 402
0.57386667 1730 305 1748 337
这个是结果展示,代码中LZ为了清晰加了非常多的打印,谁让云存储不稳定呢,动不动图片就被损坏了,哭唧唧。。。
ps:最近疫情反弹的厉害,谁能想到新冠肺炎居然坚持了一年,国外疫情也是指数性增长,这算是人类的灾难,也许多年后在看现在,又会有不一样的体会。珍惜当下,爱惜生命!