gluoncv之源码解析

本文是GluonCV官方源码解析系列的一部分,主要聚焦于VOC格式的mAP(平均精度)计算。文章详细介绍了如何利用网络输出和标签进行数据处理,并通过调用`eval_metric.update()`更新mAP值,以及`eval_metric.get()`获取计算结果。重点解析了VOC07MApMetric类如何继承并实现 EvalMetric 类中的AP计算方法。
摘要由CSDN通过智能技术生成

系列文章

一、数据集预处理.
二、voc格式的mAP计算.



前言

本项为gluoncv官方源码解析


一、voc格式的mAP计算源码解析

先介绍如何使用:
根据网络的输出和label压入数组中保存,通过调用eval_metric.update()函数,来更新mAP值。
eval_metric.get()函数返回mAP的计算结果。

# 声明一个评价函数对象,传入之后的validate函数中
val_metric = voc_metric.VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)

def validate(net, val_data, ctx, eval_metric):
    """Test on validation dataset."""
    eval_metric.reset()
    # set nms threshold and topk constraint
    net.set_nms(nms_thresh=0.45, nms_topk=400)
    mx.nd.waitall()
    net.hybridize()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
        det_bboxes = []
        det_ids = []
        det_scores = []
        gt_bboxes = []
        gt_ids = []
        gt_difficults = []
        for x, y in zip(data, label):
            # get prediction results
            ids, scores, bboxes = net(x)
            det_ids.append(ids)
            det_scores.append(scores)
            # clip to image size
            det_bboxes.append(bboxes.clip(0, batch[0].shape[2]))
            # split ground truths
            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
            gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None)

        # update metric
        eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults)
    return eval_metric.get()

VOC07MApMetric类继承自EvalMetric类,只是重写了计算mAP的方法,具体每个类别的AP计算方法,在EvalMetric类中。

"""Pascal VOC Detection evaluation."""
from __future__ import division

from collections import defaultdict
import numpy as np
import mxnet as mx
from yolov3.bbox import bbox_iou


class VOCMApMetric(mx.metric.EvalMetric):
    """
    Calculate mean AP for object detection task

    Parameters:
    ---------
    iou_thresh : float
        IOU overlap threshold for TP
    class_names : list of str
        optional, if provided, will print out AP for each class
    """
    def __init__(self, iou_thresh=0.5, class_names=None):
        super(VOCMApMetric, self).__init__('VOCMeanAP')
        if class_names is None:
            self.num = None
        else:
            assert isinstance(class_names, (list, tuple))
            for name in class_names:
                asse
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值