tf-faster-rcnn指南(六)——绘制Precision-recall曲线

目录

本文主要参考:基于Faster RCNN,如何画出类别的Precision-recall曲线

一、什么是AP,Recall,Precision
1、Precision和Recall
  Precision,准确率/查准率。Recall,召回率/查全率。这两个指标分别以两个角度衡量分类系统的准确率。
  
  例如,有一个池塘,里面共有1000条鱼,含100条鲫鱼。机器学习分类系统将这1000条鱼全部分类为“不是鲫鱼”,那么准确率也有90%(显然这样的分类系统是失败的),然而查全率为0%,因为没有鲫鱼样本被分对。这个例子显示出一个成功的分类系统必须同时考虑Precision和Recall,尤其是面对一个不平衡分类问题。
  在二元分类模型的预测结果有四种,从数学公式理解:
  混淆矩阵
  True Positive(真正,TP):将正类预测为正类数
  True Negative(真负,TN):将负类预测为负类数
  False Positive(假正,FP):将负类预测为正类数误报 (Type I error)
  False Negative(假负,FN):将正类预测为负类数→漏报 (Type II error)
  Precision和Recall的计算公式分别为:
P r e c i s i o n = T P T P + F P Precision=\frac {TP} {TP+FP} Precision=TP+FPTP
R e c a l l = T P T P + F N Recall=\frac {TP} {TP+FN} Recall=TP+FNTP
平均精度AP(average precision):就是PR曲线下的面积
二、修改代码参数,提高AP值
  源代码中默认的计算AP时的门限为0.5,可将这一数值降低,从而提高AP的结果。
1、打开/lib/datasets/pascal_voc.py
2、修改_do_python_eval()函数
  找到代码位置

filename = self._get_voc_results_file_template().format(cls)
rec, prec, ap = voc_eval(
  filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
  use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
aps += [ap]

filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
改为filename, annopath, imagesetfile, cls, cachedir, ovthresh=0,
三、画PR曲线
1、修改/lib/datasets/pascal_voc.py
  在文件头部加入:

import matplotlib.pyplot as plt
import pylab as pl
from sklearn.metrics import precision_recall_curve
from itertools import cycle

2、修改_do_python_eval()函数
  在aps += [ap]后加入:

pl.plot(rec, prec, lw=2, 
                    label='{} (AP = {:.4f})'
                          ''.format(cls, ap))

这句最重要,就是用来画图的。
再加入:

 pl.xlabel('Recall')
 pl.ylabel('Precision')
 plt.grid(True)
 pl.ylim([0.0, 1.05])
 pl.xlim([0.0, 1.05])
 pl.title('Precision-Recall')
 pl.legend(loc="lower left")     
 plt.savefig('./PR.jpg')
 plt.show()

修改完后的整个函数如下所示:

def _do_python_eval(self, output_dir = 'output'):
        annopath = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = True if int(self._year) < 2010 else False
        print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No')
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0,
                use_07_metric=use_07_metric)
            aps += [ap]
            pl.plot(rec, prec, lw=2, 
                    label='{} (AP = {:.4f})'
                          ''.format(cls, ap))
            print('AP for {} = {:.4f}'.format(cls, ap))
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
                cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        pl.xlabel('Recall')
        pl.ylabel('Precision')
        plt.grid(True)
        pl.ylim([0.0, 1.05])
        pl.xlim([0.0, 1.05])
        pl.title('Precision-Recall')
        pl.legend(loc="lower left")     
        plt.savefig('./PR.jpg')
        plt.show()
        print('Mean AP = {:.4f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print('{:.3f}'.format(ap))
        print('{:.3f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

四、运行测试指令
  这里依旧以vgg16为例

./experiments/scripts/train_faster_rcnn.sh 01 pascal_voc vgg16

在这里插入图片描述

  • 6
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值