Faster-RCNN-TensorFlow-Python3.5-master输出pr曲线计算AP值

TensorFlow 版本 Faster R-CNN 模型的评价指标都在 lib/datasets/ 目录下的 pascal_voc.py 和 voc_eval.py 中。
首先找到 pascal_voc.py,在文件开头加入以下几句:

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from itertools import cycle
import pylab as pl

然后找到这个函数更改如下:

   def _do_python_eval(self, output_dir='output'):
        """Score every class with ``voc_eval``, plot PR curves and report AP.

        For each class except ``'__background__'`` this calls ``voc_eval`` on
        the class's results file, draws its precision-recall curve on one
        shared figure, prints the AP and pickles ``rec``/``prec``/``ap`` to
        ``<output_dir>/<cls>_pr.pkl``. After the loop the figure is shown and
        the per-class APs and their mean are printed.

        Args:
            output_dir: Directory that receives the ``*_pr.pkl`` files. It is
                created (including missing parents) when it does not exist.
        """
        # Build the path with os.path.join so it resolves on every platform
        # instead of only where '\\' is the separator.
        annopath = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = int(self._year) < 2010
        print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
        os.makedirs(output_dir, exist_ok=True)
        for cls in self._classes:
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric)
            aps.append(ap)
            # One curve per class, all drawn on the same axes.
            pl.plot(rec, prec, lw=2,
                    label='Precision-recall curve of class {} (area = {:.4f})'
                    .format(cls, ap))
            print('AP for {} = {:.4f}'.format(cls, ap))
            # Pickle needs a binary file handle on Python 3.
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
                pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        # Decorate and display the combined precision-recall figure.
        pl.xlabel('Recall')
        pl.ylabel('Precision')
        plt.grid(True)
        pl.ylim([0.0, 1.2])
        pl.xlim([0.0, 1.0])
        pl.title('Precision-Recall')
        pl.legend(loc="upper right")
        plt.show()
        print('Mean AP = {:.4f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print('{:.3f}'.format(ap))
        print('{:.3f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

在这个文件的另一个函数中注释掉以下几行(为了保留检测后生成的每一类预测框的文本文件):

   def evaluate_detections(self, all_boxes, output_dir):
        """Write the detection results, then run the evaluation(s) on them.

        Args:
            all_boxes: Detections handed unchanged to
                ``_write_voc_results_file``.
            output_dir: Directory passed on to the Python evaluation and,
                when ``self.config['matlab_eval']`` is truthy, to the MATLAB
                evaluation as well.
        """
        self._write_voc_results_file(all_boxes)
        self._do_python_eval(output_dir)

        matlab_requested = self.config['matlab_eval']
        if matlab_requested:
            self._do_matlab_eval(output_dir)

        # The cleanup step below is deliberately disabled so the per-class
        # results files are still on disk after evaluation.
        # if self.config['cleanup']:
        #     for cls in self._classes:
        #         if cls == '__background__':
        #             continue
        #         filename = self._get_voc_results_file_template().format(cls)
        #         os.remove(filename)

voc_eval.py文件中做如下更改:

def parse_rec(filename):
    """Parse a PASCAL VOC annotation XML file into a list of object dicts.

    Args:
        filename: Path of the annotation XML file, used as given (no
            directory prefix is prepended).

    Returns:
        A list with one dict per ``<object>`` element, holding the keys
        ``name``, ``pose``, ``truncated``, ``difficult`` and ``bbox``
        (``[xmin, ymin, xmax, ymax]`` as ints).

    The ``pose``, ``truncated`` and ``difficult`` tags are treated as
    optional: when one is missing or empty it falls back to
    ``'Unspecified'`` / ``0`` / ``0`` instead of raising ``AttributeError``.
    Coordinates written as floats (e.g. ``"12.0"``) are accepted and
    truncated to ints.
    """
    tree = ET.parse(filename)
    objects = []
    for obj in tree.findall('object'):
        bbox = obj.find('bndbox')
        objects.append({
            'name': obj.find('name').text,
            'pose': obj.findtext('pose') or 'Unspecified',
            'truncated': int(obj.findtext('truncated') or 0),
            'difficult': int(obj.findtext('difficult') or 0),
            # int(float(...)) tolerates tools that export "12.0" style values.
            'bbox': [int(float(bbox.find(tag).text))
                     for tag in ('xmin', 'ymin', 'xmax', 'ymax')],
        })

    return objects

(即 ET.parse 直接使用传入的 filename,前面不再拼接任何路径前缀)
在 Faster-RCNN-TensorFlow-Python3.5-master 文件夹下新建 test-net.py:

#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Test script that evaluates a trained Faster R-CNN checkpoint on the VOC 2007 test set.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

import tensorflow as tf
from lib.nets.vgg16 import vgg16
from lib.datasets.factory import get_imdb
from lib.utils.test import test_net

# Alternative entries kept for reference:
# NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
# Checkpoint file written by training, keyed by network name.
NETS = {
    'vgg16': ('vgg16_faster_rcnn_iter_5000.ckpt',),
}
# Directory name component of the checkpoint path, keyed by dataset choice.
DATASETS = {
    'pascal_voc': ('voc_2007_trainval',),
    'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',),
}
def parse_args():
    """Parse the command-line options of the test script.

    Returns:
        The ``argparse`` namespace with ``demo_net`` (set by ``--net``,
        default ``'vgg16'``) and ``dataset`` (set by ``--dataset``, default
        ``'pascal_voc'``). Allowed values are the keys of ``NETS`` and
        ``DATASETS`` respectively.
    """
    cli = argparse.ArgumentParser(description='Tensorflow Faster R-CNN test')
    cli.add_argument(
        '--net',
        dest='demo_net',
        default='vgg16',
        choices=NETS.keys(),
        help='Network to use [vgg16 res101]',
    )
    cli.add_argument(
        '--dataset',
        dest='dataset',
        default='pascal_voc',
        choices=DATASETS.keys(),
        help='Trained dataset [pascal_voc pascal_voc_0712]',
    )
    return cli.parse_args()
if __name__ == '__main__':
    # Entry point: restore a trained checkpoint and run test_net on the
    # "voc_2007_test" image database.
    args = parse_args()
    demonet = args.demo_net
    dataset = args.dataset
    # Checkpoint path: output/<net>/<dataset dir>/default/<checkpoint file>
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])
    # Checkpoint name without its extension. os.path.basename honours the
    # platform's separator, whereas splitting on a literal backslash only
    # works on Windows.
    filename = os.path.basename(os.path.splitext(tfmodel)[0])
    filename = 'default' + '/' + filename
    imdb = get_imdb("voc_2007_test")
    # NOTE(review): a non-empty string is passed here; presumably it is only
    # used as a truthy flag -- confirm against imdb.competition_mode.
    imdb.competition_mode('competition mode')
    if not os.path.isfile(tfmodel + '.meta'):
        print(tfmodel)
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))
    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    # init session
    sess = tf.Session(config=tfconfig)
    # try/finally guarantees the session is closed even when building the
    # network, restoring the weights or the evaluation itself raises.
    try:
        # load network
        if demonet == 'vgg16':
            net = vgg16(batch_size=1)
        # elif demonet == 'res101':
        # net = resnetv1(batch_size=1, num_layers=101)
        else:
            raise NotImplementedError
        # The third argument must equal the number of object classes + 1;
        # change it to match your dataset.
        net.create_architecture(sess, "TEST", 8,
                                tag='default', anchor_scales=[8, 16, 32])
        saver = tf.train.Saver()
        saver.restore(sess, tfmodel)
        print('Loaded network {:s}'.format(tfmodel))
        print(filename)
        test_net(sess, net, imdb, filename, max_per_image=100)
    finally:
        sess.close()

(记得修改相应的路径和分类的类别数)
点击运行即可生成几个文本文件和pr曲线
在这里插入图片描述
在这里插入图片描述

参考博客1
参考博客2

  • 6
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 28
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 28
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值