如何计算DOTA格式数据集的Recall和Precision

简介

在 DOTA 格式的数据集中,包含多种目标类型(如车辆、建筑物、桥梁等),不同类型目标的检测性能评估时,Recall(召回率) 和 Precision(精确率) 是两个最常用的评价指标。本文将介绍如何基于 DOTA 数据集计算这些指标,并展示一段通用的代码来帮助大家进行精度评估。

背景

DOTA(Object Detection in Aerial Images)数据集广泛应用于遥感图像目标检测任务中。每张图像包含多个目标,每个目标由一个旋转边界框(Rotated Bounding Box, RBox)标注。模型的输出一般是pkl格式的预测旋转框,包含预测的目标位置、目标类别和置信度。通过计算真实框和预测框之间的匹配情况,可以得到目标检测任务中的重要评估指标——召回率和精确率。

评估过程

1. 数据准备

首先,需要准备以下几部分数据:

  • 人工标注数据(ground truth): 包含图像中的真实目标旋转框、类别标签等信息,通常以 .txt 文件存储,格式为:
[x1, y1, x2, y2, x3, y3, x4, y4, 类别标签, 困难度]
  • 模型预测结果: 每个预测结果包含一个旋转框、类别标签和置信度分数,通常以 .pkl 文件存储。注意:模型预测结果文件与人工标注数据文件的前缀名需要相同。

2. 核心代码

我们通过以下步骤进行Recall和Precision的计算:

  • 旋转框转换: 将标注的多边形坐标转换为旋转框格式,以便于计算IoU。
  • IoU计算: 通过 mmcv 中的 box_iou_rotated 函数来计算IoU。
  • 多线程处理: 使用多线程加速每个文件的处理。
  • 精确率和召回率计算: 按照每个类别分别计算TP、FP、FN。

3. 通用代码

import os
from multiprocessing import get_context
import numpy as np
import torch
import mmcv
import pandas as pd
from mmcv.ops import box_iou_rotated
from concurrent.futures import ThreadPoolExecutor, as_completed

# 解析pkl文件中的每个类型的检测框数据
def parse_rotated_model_dota(model_dotas, class_map_cleaned):
    """
    解析模型检测的旋转框数据。
    返回格式:[(coords, dota_type, score), ...]
    """
    results = []
    for i, dota_type_res in enumerate(model_dotas):
        dota_type = class_map_cleaned[i]
        if isinstance(dota_type_res, np.ndarray) and dota_type_res.shape[1] == 6:
            # 处理检测框,假设每行包含 [cx, cy, w, h, theta, score]
            for bbox_with_score in dota_type_res:
                coords = bbox_with_score[:5]  # 5个参数
                score = bbox_with_score[5]    # 置信度得分
                results.append((coords, dota_type, score))
        else:
            print(f"  Unknown format for dota type {dota_type}: {dota_type_res}")
    return results


def polygon_to_rotated_box(polygon):
    """
    将8参数多边形(四个点的坐标)转换为5参数旋转框。
    """
    # 将多边形顶点转换为numpy数组
    poly_points = np.array(polygon, dtype=np.float32).reshape(-1, 2)
    # 获取最小外接矩形
    rect = cv2.minAreaRect(poly_points)
    (cx, cy), (w, h), theta = rect
    # OpenCV返回的角度是负角度,需要转换成正角度
    if w < h:
        w, h = h, w
        theta += 90
    theta = np.deg2rad(theta)  # 将角度转换为弧度
    return cx, cy, w, h, theta


def parse_dota_data(file_path, class_map):
    """
    解析文件中的标签数据。
    """
    dotas = []
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.strip().split()
            coords = tuple(map(float, parts[:8]))  # 假设坐标存储在前8个元素
            dota_type = parts[8]
            if dota_type in class_map:
                dotas.append((coords, dota_type))
    return dotas 

def calculate_recall_precision(class_map, dataset, results, filenames, tp_iou_thr=0.5, score_thr=0.0, max_workers=30):
    """
    通用的DOTA数据集的Recall和Precision计算函数
    Args:
        class_map (list): 目标类别名称列表(如:['small_vehicle', 'large_vehicle', ...])。
        dataset (list): 人工标注的数据,格式为[(filename, coords, class_name, difficulty), ...],其中
                        coords 是旋转框坐标,class_name 是目标类别,difficulty 是难度标识。
        results (dict): 检测结果,格式为{filename: [(coords, class_name, score), ...]},coords 是旋转框,
                        class_name 是预测的类别,score 是检测模型给出的置信度。
        filenames (list): 包含每个文件名的列表,用于遍历和处理每个文件的数据。
        tp_iou_thr (float): 用于判断是否匹配的IoU阈值(默认值为0.5)。
        score_thr (float): 置信度阈值(低于该值的检测框将被过滤掉,默认值为0.0)。
        max_workers (int): 并发处理文件的最大线程数(默认值为30)。
    
    Returns:
        该函数将生成两个CSV文件:
        1. 'dota_grouped_results.csv':每个文件的Recall和Precision结果。
        2. 'dota_overall_results.csv':按类别计算的总体Recall和Precision。
    """
    
    # 统计目标类别的数量
    num_classes = len(class_map)
    
    # 初始化用于存储每个文件的Recall和Precision结果的数据框
    df_results = pd.DataFrame(columns=['Filename'] + [f'Recall_{c}' for c in class_map] + [f'Precision_{c}' for c in class_map])

    # 初始化总的TP、FP、FN的计数器
    total_tp = np.zeros(num_classes)
    total_fp = np.zeros(num_classes)
    total_fn = np.zeros(num_classes)

    # 定义处理每个文件的辅助函数
    def process_file(filename):
        """
        处理单个文件,计算该文件中每个类别的TP、FP、FN,并返回该文件的Recall和Precision。
        Args:
            filename: 需要处理的文件名。
        
        Returns:
            该文件的结果,包括:
            - 文件名
            - 每个类别的Recall、Precision
            - 每个类别的TP、FP、FN
        """
        # 获取当前文件的标注数据和检测结果
        gt_dota = [d for d in dataset if d[0] == filename]  # 获取ground truth(标注)数据
        result_dota = [d for d in results.get(filename, []) if d[2] >= score_thr]  # 过滤掉低于置信度阈值的检测结果

        # 将多边形标注数据转换为旋转框
        gt_bboxes_rotated = [polygon_to_rotated_box(d[1]) for d in gt_dota]
        gt_bboxes = torch.tensor(gt_bboxes_rotated, dtype=torch.float32)  # 转换为Tensor格式的ground truth边界框
        gt_labels = [class_map.index(d[2]) for d in gt_dota]  # 获取ground truth的类别标签

        # 将检测结果的边界框和类别标签也转换为Tensor
        result_bboxes = torch.tensor([d[0] for d in result_dota], dtype=torch.float32)
        result_labels = [class_map.index(d[1]) for d in result_dota]  # 获取检测结果的类别标签

        # 计算ground truth和检测结果之间的IoU
        ious = box_iou_rotated(gt_bboxes, result_bboxes)

        # 初始化当前文件每个类别的TP、FP、FN计数
        tp = np.zeros(num_classes)
        fp = np.zeros(num_classes)
        fn = np.zeros(num_classes)

        # 用于记录每个ground truth框是否被匹配
        matched = np.zeros(len(gt_labels), dtype=bool)

        # 遍历检测结果并与ground truth进行匹配
        for i, (result_bbox, result_label) in enumerate(zip(result_bboxes, result_labels)):
            overlaps = ious[:, i]  # 获取当前检测框与所有ground truth框的IoU
            matched_with_gt = False  # 标志检测框是否成功匹配ground truth框

            # 遍历每个ground truth框
            for j, (overlap, gt_label) in enumerate(zip(overlaps, gt_labels)):
                # 如果IoU超过阈值且类别一致,则该检测框为TP
                if overlap > tp_iou_thr and gt_label == result_label:
                    tp[result_label] += 1  # 对应类别的TP计数器+1
                    matched[j] = True  # 标记ground truth框已匹配
                    matched_with_gt = True  # 当前检测框找到匹配
                    break  # 找到匹配后停止查找

            # 如果没有找到匹配,则该检测框为FP
            if not matched_with_gt:
                fp[result_label] += 1

        # 统计没有匹配的ground truth框,记为FN
        for j, matched_flag in enumerate(matched):
            if not matched_flag:
                fn[gt_labels[j]] += 1

        # 计算每个类别的Recall和Precision
        recall = tp / (tp + fn + 1e-6)  # 防止除以0
        precision = tp / (tp + fp + 1e-6)  # 防止除以0
        
        # 文件名作为键值,返回结果
        filename_key = '_'.join(filename.split('_')[:-1])
        return filename_key, recall, precision, tp, fp, fn

    # 使用多线程处理文件,加快计算速度
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # 提交每个文件的处理任务
        future_to_file = {executor.submit(process_file, filename): filename for filename in filenames}

        # 收集处理结果
        for future in as_completed(future_to_file):
            # 获取每个文件的结果
            filename_key, recall, precision, tp, fp, fn = future.result()

            # 将结果构造成DataFrame并更新到结果表中
            recall_dict = {f'Recall_{class_map[i]}': recall[i] for i in range(num_classes)}
            precision_dict = {f'Precision_{class_map[i]}': precision[i] for i in range(num_classes)}
            new_row = {'Filename': filename_key}
            new_row.update(recall_dict)
            new_row.update(precision_dict)
            df_results = pd.concat([df_results, pd.DataFrame([new_row])], ignore_index=True)

            # 更新总体的TP、FP、FN计数
            total_tp += tp
            total_fp += fp
            total_fn += fn

    # 计算总体的Recall和Precision
    overall_recall = total_tp / (total_tp + total_fn + 1e-6)  # 防止除以0
    overall_precision = total_tp / (total_tp + total_fp + 1e-6)  # 防止除以0

    # 打印每个类别的总体Recall和Precision
    for i, class_name in enumerate(class_map):
        print(f"Overall Recall for {class_name}: {overall_recall[i]:.4f}")
        print(f"Overall Precision for {class_name}: {overall_precision[i]:.4f}")

    # 将总体结果构造成DataFrame并保存为CSV文件
    overall_results = pd.DataFrame({
        'Class': class_map,
        'Overall Recall': overall_recall,
        'Overall Precision': overall_precision
    })

    # 保存每个文件的Recall和Precision结果到CSV
    df_results.to_csv('dota_grouped_results.csv', index=False)
    
    # 保存总体Recall和Precision结果到CSV
    overall_results.to_csv('dota_overall_results.csv', index=False)

def load_all_pkl_files(pkl_path):
    if pkl_path.endswith('.pkl'):
        file_results = mmcv.load(pkl_path)
    return file_results[0]

# 示例配置
class_map = ['small_vehicle', 'large_vehicle', 'plane', 'ship', 'harbor', 'baseball_diamond', 'tennis_court', 'basketball_court']

# 解析标注和模型结果
anno_folder = 'path_to_annotations'
model_folder = 'path_to_model_results'

filenames = []
anno_dota_list = []
model_dota_list = []

for anno_file in os.listdir(anno_folder):
    anno_path = os.path.join(anno_folder, anno_file)
    anno_dota = parse_dota_data(anno_path, class_map)
    
    model_path = os.path.join(model_folder, anno_file.replace('.txt', '.pkl'))
    model_dota = parse_rotated_model_dota(load_all_pkl_files(model_path), class_map)

    anno_dota_list.extend([(anno_file, *d) for d in anno_dota])
    model_dota_list.append((anno_file, model_dota))
    filenames.append(anno_file)

# 计算 recall 和 precision
results_dict = {fname: res for fname, res in model_dota_list}
calculate_recall_precision(class_map, anno_dota_list, results_dict, filenames, tp_iou_thr=0.5, score_thr=0.3)

函数说明

1. parse_rotated_model_dota(model_dotas, class_map_cleaned)

解析模型检测结果,提取每个类别的旋转框和对应的置信度得分。

  • 输入参数:
    • model_dotas: 模型预测结果,包含各类别的旋转框和置信度。
    • class_map_cleaned: 类别映射,定义类别名称。
  • 输出:
    • 每个类别的检测框及其置信度得分列表,格式为 [(coords, class_name, score), ...]

2. polygon_to_rotated_box(polygon)

将多边形坐标转换为旋转框(中心点、宽、高、旋转角度)。

  • 输入参数:
    • polygon: 8个顶点的多边形坐标。
  • 输出:
    • 返回5个参数的旋转框(中心点x、y,宽、高,旋转角度)。

3. parse_dota_data(file_path, class_map)

解析DOTA数据集标注文件,提取目标的多边形坐标和类别信息。

  • 输入参数:
    • file_path: 标注文件路径。
    • class_map: 类别映射表,定义合法的类别。
  • 输出:
    • 标注数据的列表,格式为 [(coords, class_name), ...]

4. calculate_recall_precision(class_map, dataset, results, filenames, tp_iou_thr=0.5, score_thr=0.0, max_workers=30)

计算DOTA数据集的 RecallPrecision,并生成评估结果。

  • 输入参数:
    • class_map: 目标类别名称列表。
    • dataset: 标注数据列表,格式为 [(filename, coords, class_name, difficulty), ...]
    • results: 模型检测结果,格式为 {filename: [(coords, class_name, score), ...]}
    • filenames: 文件名列表。
    • tp_iou_thr: IoU阈值,默认0.5,用于判定检测框是否匹配真实框。
    • score_thr: 置信度阈值,过滤掉低置信度的检测结果。
    • max_workers: 最大线程数,默认为30。
  • 输出:
    • 生成两个CSV文件:
      • dota_grouped_results.csv:每个文件的 RecallPrecision 结果。
      • dota_overall_results.csv:总体的 RecallPrecision 按类别汇总结果。

5. load_all_pkl_files(pkl_path)

加载并解析 .pkl 文件中的检测结果。

  • 输入参数:
    • pkl_path: .pkl 文件路径。
  • 输出:
    • 模型的检测结果数据。

引用库和链接

  1. NumPy

    • 用于数组和矩阵运算,处理检测框和多边形坐标的转换。
    • NumPy官网
  2. PyTorch

    • 用于张量处理和矩阵运算,在计算IoU时进行高效的张量操作。
    • PyTorch官网
  3. MMCV

    • 来自开源目标检测框架MMDetection,提供旋转框IoU计算和其他实用工具。
    • MMCV文档
  4. Pandas

    • 用于数据框处理和结果的保存,将 RecallPrecision 结果保存为CSV文件。
    • Pandas官网
  5. Concurrent Futures

  6. OpenCV

    • 用于图像处理和多边形最小外接矩形的计算。
    • OpenCV官网

结论

通过以上代码,您可以通用化地计算DOTA数据集中不同类型目标的召回率和精确率。该代码可以适应不同类型的目标检测任务,用户只需修改类别映射、标注数据和检测结果路径即可适应不同数据集。

---

希望这篇博客对你有所帮助,如果你喜欢这篇文章,请点赞或关注,我会持续分享更多实用的 目标检测 技术内容!

---

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值