Faster R-CNN的测试过程分析

引言

  这篇博客里,我主要分析一下faster rcnn的测试过程是如何实现的。每个小结我都会以某个py文件的名字作为标题,表示以下内容是对此文件的分析。

test.py

  test.py是用来测试网络的准确度的主要代码,下面我来分析下这个文件里面最主要的函数test_net()。
  test_net的输入是faster r-cnn网络,图片等信息,输出的是对这些图片里面物体进行预测的准确率。

def test_net(net, imdb, max_per_image=100, thresh=0.05, vis=False):
    """Test a Fast R-CNN network on an image database."""
    num_images = len(imdb.image_index)
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in xrange(num_images)]
                 for _ in xrange(imdb.num_classes)]
    # 定义程序输出的路径
    output_dir = get_output_dir(imdb, net)

    # timers
    _t = {
  'im_detect' : Timer(), 'misc' : Timer()}

    if not cfg.TEST.HAS_RPN:
        roidb = imdb.roidb
    # 遍历每一张图片
    for i in xrange(num_images):
        # filter out any ground truth boxes
        if cfg.TEST.HAS_RPN:
            box_proposals = None
        else:
            # roidb里可能有ground-truth的rois,这会影响检测结果(使得结果编号),所以我们要把这些bbox剔除
            box_proposals = roidb[i]['boxes'][roidb[i]['gt_classes'] == 0]
        # 读取图片
        im = cv2.imread(imdb.image_path_at(i))
        _t['im_detect'].tic()
        # 得到这张图片的预测bbox和bbox的得分,具体数据类型如下
        # scores (ndarray): R x K array of object class scores 
        # (K includes background as object category 0)
        # boxes (ndarray): R x (4*K) array of predicted bounding boxes
        scores, boxes = im_detect(net, im, box_proposals)
        _t['im_detect'].toc()

        _t['misc'].tic()
        # 对于每张图片,从类别1开始统计预测结果(类别0是背景类)
        for j in xrange(1, imdb.num_classes):
            # 取出score大于某个阈值的下标
            inds = np.where(scores[:, j] > thresh)[0]
            cls_scores = scores[inds, j]
            cls_boxes = boxes[inds, j*4:(j+1)*4]
            # 将scores与bbox合在一起,得到dets
            cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
                .astype(np.float32, copy=False)
            keep = nms(cls_dets, cfg.TEST.NMS)
            cls_dets = cls_dets[keep, :]
            # vis==True时,将框框与图片可视化,显示在屏幕上
            if vis:
                vis_detections(im, imdb.classes[j], cls_dets)
            all_boxes[j][i] = cls_dets

        # 将每张图片的检测个数限制在max_per_image之内
        if max_per_image > 0:
            image_scores = np.hstack([all_boxes[j][i][:, -1]
                                      for j in xrange(1, imdb.num_classes)])
            if len(image_scores) > max_per_image:
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in xrange(1, imdb.num_classes):
                    keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
                    all_boxes[j][i] = all_boxes[j][i][keep, :]
        _t['misc'].toc()

        print 'im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
              .format(i + 1, num_images, _t['im_detect'].average_time,
                      _t['misc'].average_time)
    # 将检测结果保存
    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)
    # 评估检测结果
    print 'Evaluating detections'
    imdb.evaluate_detections(all_boxes, output_dir)

   那么问题来了,这个imdb是个什么东西,它是怎么评估测试准确率的?
   找到imdb的调用位置,一直找到它的出生地,发现在pascal_voc.py里面,下面分析下这段代码。

pascal_voc.py中MAp评测的相关代码

  这个代码写的是关于imdb的操作。先看看imdb.evaluate_detections():

def evaluate_detections(self, all_boxes, output_dir):
  self._write_voc_results_file(all_boxes)
  # pyhon版本的评测(嗯再看看这个里面怎么实现评测的)
  self._do_python_eval(output_dir)
  if self.config['matlab_eval']:
      self._do_matlab_eval(output_dir)
  if self.config['cleanup']:
      for cls in self._classes:
          if cls == '__background__':
              continue
          filename = self._get_voc_results_file_template().format(cls)
          os.remove(filename)

  进入self._do_python_eval,发现它还调用了一个voc_eval,原来这个函数才是计算准确率的。下面再进入这个函数看看。

rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric)

  进入voc_eval看看:
  它的输入是dets文件(scores与bboxes),annopath(ground-truth的标注文件),imagesetfile,cls(某个特定类别),cachedir(缓存的路径),use_07_metric(评价的标准)。
  它

  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值