这篇博客介绍测试过程中使用的评价函数。在MXNet框架下，这类评价函数都可以通过继承mx.metric.EvalMetric类来实现。
该项目evaluate文件夹下的脚本eval_metric.py定义了测试过程中使用的评价函数。这个脚本主要涉及两个类：MApMetric和VOC07MApMetric。后者继承自前者并重写了部分方法，因此MApMetric类是核心。两者都用于计算目标检测（object detection）算法的mAP（mean average precision，平均精度均值）。
import mxnet as mx
import numpy as np
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class MApMetric(mx.metric.EvalMetric):
"""
Calculate mean AP for object detection task
Parameters:
---------
ovp_thresh : float
overlap threshold for TP
use_difficult : boolean
use difficult ground-truths if applicable, otherwise just ignore
class_names : list of str
optional, if provided, will print out AP for each class
pred_idx : int
prediction index in network output list
roc_output_path
optional, if provided, will save a ROC graph for each class
tensorboard_path
optional, if provided, will save a ROC graph to tensorboard
"""
# __init__中还是执行常规的重置操作:reset()和一些赋值操作。
def __init__(self, ovp_thresh=0.5, use_difficult=False, class_names=None,
             pred_idx=0, roc_output_path=None, tensorboard_path=None):
    """Initialize the metric and reset its accumulated statistics.

    Parameters
    ----------
    ovp_thresh : float
        Overlap threshold for counting a detection as a true positive
        (per the class docstring; only stored here).
    use_difficult : bool
        Whether to use "difficult" ground truths (per the class docstring;
        only stored here).
    class_names : list or tuple of str, optional
        When given, the metric keeps one statistic per class plus a final
        one for the overall mAP. When None, a single scalar statistic is
        kept.
    pred_idx : int
        Index of the prediction in the network output list; coerced to int.
    roc_output_path : str, optional
        Directory for per-class curve images (stored only here).
    tensorboard_path : str, optional
        Tensorboard output location (stored only here).
    """
    super(MApMetric, self).__init__('mAP')
    if class_names is None:
        # No per-class breakdown: reset() will use scalar statistics.
        self.num = None
    else:
        assert isinstance(class_names, (list, tuple))
        for name in class_names:
            assert isinstance(name, str), "must provide names as str"
        # One entry per class plus a trailing 'mAP' entry. list() is needed
        # because a tuple is accepted above and tuple + list raises TypeError.
        self.name = list(class_names) + ['mAP']
        self.num = len(class_names) + 1
    # reset() sizes num_inst / sum_metric from self.num, so it must run
    # after self.num has been assigned in both branches.
    self.reset()
    self.ovp_thresh = ovp_thresh
    self.use_difficult = use_difficult
    self.class_names = class_names
    self.pred_idx = int(pred_idx)
    self.roc_output_path = roc_output_path
    self.tensorboard_path = tensorboard_path
def save_roc_graph(self, recall=None, prec=None, classkey=1, path=None, ap=None):
    """Save a per-class precision/recall curve as a PNG image.

    Despite the ``roc`` naming, the plotted curve is precision (y axis)
    against recall (x axis), both clamped to [0, 1].

    Parameters
    ----------
    recall : array-like
        Recall values (x coordinates).
    prec : array-like
        Precision values (y coordinates), same length as ``recall``.
    classkey : int
        Index into ``self.class_names``. It selects the plot title and the
        output file name ``roc_<class name>.png``.
    path : str, optional
        Output directory, created (including missing parents) if needed.
        Falls back to ``self.roc_output_path`` when omitted.
    ap : float, optional
        Average precision shown in the legend. The legend is omitted when
        None.

    Raises
    ------
    ValueError
        If no directory is given and ``self.roc_output_path`` is unset.
    """
    if path is None:
        path = self.roc_output_path
    if path is None:
        raise ValueError('no output directory given for the ROC graph')
    if not os.path.isdir(path):
        os.makedirs(path)
    class_name = self.class_names[classkey]
    # Use an explicit extension. With no extension, savefig appends one
    # itself, so the old "remove if exists" check never matched the file
    # that was actually written. savefig overwrites existing files anyway.
    plot_path = os.path.join(path, 'roc_' + class_name + '.png')
    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.set_title(class_name)
        if ap is None:
            ax.plot(recall, prec, 'b')
        else:
            ax.plot(recall, prec, 'b', label='AP = %0.2f' % ap)
            ax.legend(loc='lower right')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_ylabel('Precision')
        ax.set_xlabel('Recall')
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting/saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
def reset(self):
"""Clear the internal statistics to initial state."""
if getattr(self, 'num', None) is None:
self.num_inst = 0
self.sum_metric = 0.0
else:
self.num_inst = [0] * self.num
self.sum_metric = [0.0] * self.num
self.records = dic