需要用到shapely这个库,我的环境是 Win10 x64,python3.7,该库可在以下网址下载:
https://www.lfd.uci.edu/~gohlke/pythonlibs/
前面的博客里写了利用DOTA_Devkit计算AP。有时候我们还需要知道特定置信度下的检出率recall和准确率precision,DOTA_Devkit就不大好用了,想来不如自己写个脚本计算。
先贴出计算公式,TP为检测正确的个数,FP为检测错误的数量,GT为标签个数:
recall = TP / GT
precision = TP / TP + FP
需要用来计算的结果文件还是跟DOTA_Devkit需要的一样,具体的代码和解释如下:
# *_* coding : UTF-8 *_*
# 开发人员: csu·pan-_-||
# 开发时间: 2020/11/19 10:41
# 文件名称: evalRRPN.py
# 开发工具: PyCharm
# 功能描述: 对Detectron2里的RRPN任意多边形的目标检测结果进行定量分析
# 计算检出率 recall 和准确率 precision
"""
目前的水平框分析工具不适用于旋转框
根据任意多边形的预测结果求取与label的iou
每个类有一个txt文档的预测框,内容如下:
00026 0.9996 716 241 760 253 717 412 673 400
00026 0.9993 635 243 679 257 630 414 586 401
图片名 置信度 四点坐标
一个图片一个txt格式的标签:
imagesource:GoogleEarth
gsd:0.115726939386
126 179 229 269 215 286 112 195 ship 0
136 164 240 254 227 269 123 179 ship 0
四点坐标 类别 difficult
"""
import numpy as np
import shapely
from shapely.geometry import Polygon, MultiPoint # 多边形计算的库
import os
# 标签和预测结果的路径:
labelTxtPath = r'E:\Projects\DOTA_devkit-master\warship_result\txts_dota'
preTxtPath = r'E:\Projects\DOTA_devkit-master\warship_result\infer4'
classname = ['plane','ship'] # 目标的类别名称
label_num = [] # 把类别名字存储在列表里,方便统计每个类的label数量
TP,FP = [],[] # 把TP和FP的类别名字存储在列表里
iouthresh = 0.5 # 设置预测框和标签需要的iou阈值,超过它才能被认为是TP
confthresh = 0.25 # 设置预测框置信度的阈值,超过它才能被认为是TP
# 求每个类的recall和precision
def rec_prec_per_class(pretxt):
"""
:param pretxt: 每个类别的预测txt文件名,如 carrier.txt
:return: 将遍历得到的标签存储在label_num列表里,计算得到的TP和FP存在列表里
"""
pre_class = pretxt.replace('.txt','') # 得到文档的类别名称
isFP = True
with open(os.path.join(preTxtPath,pretxt), "r") as f:
for line in f.readlines():
line = line.strip('\n') # 去掉列表中每一个元素的换行符
pre_list = line.split(' ')
# ['00026', '0.9996', '716' , '241', '760', '253', '717', '412', '673', '400']
prebox = pre_list[2:]
# ['716' , '241', '760', '253', '717', '412', '673', '400']
prebox = list(map(int, prebox))
# [716, 241, 760, 253, 717, 412, 673, 400]
# 对每一行预测标签,循环遍历label标签求iou
with open(os.path.join(labelTxtPath,pre_list[0] + '.txt'), "r") as fl:
data = fl.readlines()[2:] # 去掉前面两行无用信息imagesource、gsd
for line_l in data:
line_l = line_l.strip('\n') # 去掉列表中每一个元素的换行符
label_list = line_l.split(' ')
# ['126', '179', '229', '269', '215', '286', '112', '195', 'warship', '0']
label_box = label_list[0:8] # 取前面四点坐标
label_box = list(map(int, label_box)) # 转换为int型
if label_list[8] == pre_class:
iou = compute_IOU(prebox, label_box)
# iou和置信度都大于阈值才认为是检测正确
if iou >= iouthresh and float(pre_list[1]) >= confthresh:
TP.append(pre_class) # TP添加一个
isFP = False
break
if isFP:
FP.append(pre_class) # 没有iou大于阈值的,FP添加一个
else:
isFP = True # 遍历完一个标签文件,将isFP设置为True
print('succeed process: %s' % pretxt)
# 计算每个类别标签的数量
def count_numpos(labelTxtPath):
"""
:param labelTxtPath: 每个类别的预测txt文件名,如 carrier.txt
:return: 无。 将遍历得到的标签存储在label_num列表里
"""
labelTxts = os.listdir(labelTxtPath)
for labeltxt in labelTxts:
with open(os.path.join(labelTxtPath, labeltxt), "r") as f:
data = f.readlines()[2:] # 去掉前面两行无用信息imagesource、gsd
for line_l in data:
line_l = line_l.strip('\n') # 去掉列表中每一个元素的换行符
label_list = line_l.split(' ')
# ['126', '179', '229', '269', '215', '286', '112', '195', 'warship', '0']
label_num.append(label_list[8]) # 添加标签的类别至列表,以便后面统计各个类别的目标数量
# 求任意四边形iou
def compute_IOU(line1,line2):
# 四边形四个点坐标的一维数组表示,[x,y,x,y....]
# 如:line1 = [728, 252, 908, 215, 934, 312, 752, 355]
# 返回iou的值,如 0.7
line1_box = np.array(line1).reshape(4, 2) # 四边形二维坐标表示
# 凸多边形与凹多边形
poly1 = Polygon(line1_box)
# .convex_hull # python四边形对象,会自动计算四个点,最后四个点顺序为:左上 左下 右下 右上 左上
line2_box = np.array(line2).reshape(4, 2)
# 凸多边形与凹多边形
poly2 = Polygon(line2_box)
union_poly = np.concatenate((line1_box, line2_box)) # 合并两个box坐标,变为8*2
if not poly1.intersects(poly2): # 如果两四边形不相交
iou = 0
else:
try:
inter_area = poly1.intersection(poly2).area # 相交面积
union_area = MultiPoint(union_poly).convex_hull.area
if union_area == 0:
iou = 0
else:
iou = float(inter_area) / union_area
except shapely.geos.TopologicalError:
print('shapely.geos.TopologicalError occured, iou set to 0')
iou = 0
return iou
if __name__ == '__main__':
preTxts = os.listdir(preTxtPath)
for pretxt in preTxts:
rec_prec_per_class(pretxt)
count_numpos(labelTxtPath) # 运行一遍获得每个类别的标签数量
for class_i in classname:
print('class_name:',class_i)
numpos = label_num.count(class_i) # 每个类别的目标总数
print('num_object: ',numpos)
recall = TP.count(class_i) / numpos
precision = TP.count(class_i) / (TP.count(class_i) + FP.count(class_i))
print('recall: %.4f'%recall)
print('precision: %.4f'%precision)
print()
结语: 感觉计算的过程时间复杂度比较高,先遍历了一遍标签文件求出每个类别的gt个数,再双重循环求预测框和标签的iou,一下子想不到更好的优化办法了,先把功能实现吧,欢迎批评指正。