[Beginner-Friendly Step-by-Step Tutorial] Outputting Precision/Recall/F1-Score Metrics from MMDetection 3 Training

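I recently needed to run comparison experiments. Because the MMDetection framework integrates a large number of algorithms, I trained models with it and kept this record. Questions and corrections are welcome in the comments.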

Setup: after training finished, I kept the checkpoint from every epoch under work_dirs/, named epoch_1.pth through epoch_150.pth.

Formulas:

Precision, Recall and the F1 score are important metrics for evaluating the performance of classification models, especially on imbalanced datasets.

### Precision
Precision is the fraction of samples predicted as positive that are actually positive:
\[ \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \]

where:
- TP (True Positives): samples correctly predicted as positive.
- FP (False Positives): samples incorrectly predicted as positive.

### Recall
Recall is the fraction of actually positive samples that are correctly predicted as positive:
\[ \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} \]

where:
- FN (False Negatives): samples that are actually positive but were incorrectly predicted as negative.

### F1 Score
The F1 score is the harmonic mean of precision and recall, so it rewards models that do well on both at once:
\[ F1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} \]

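For a fixed arithmetic mean of precision and recall, F1 is highest when the two are equal; when one is very small and the other large, F1 is low, pulled toward the smaller of the two.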
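As a quick numeric check of the three formulas, the minimal sketch below computes them from raw TP/FP/FN counts. The function name `precision_recall_f1` and the convention of returning 0.0 when a denominator is zero are my own assumptions, not MMDetection behavior.

```python
def precision_recall_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Compute Precision, Recall and F1 from raw counts.

    Assumed convention: any ratio whose denominator is zero is reported as 0.0
    (some tools report NaN instead).
    """
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Example: 80 correct detections, 20 false alarms, 40 missed objects
p, r, f1 = precision_recall_f1(tp=80, fp=20, fn=40)
print(f'P={p:.3f} R={r:.3f} F1={f1:.3f}')  # P=0.800 R=0.667 F1=0.727
```

Equivalently, F1 = 2TP / (2TP + FP + FN), here 160 / 220 ≈ 0.727.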

1. Metrics for a single epoch's checkpoint:

Run:

python tools/test.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py work_dirs/epoch_xx.pth --out=result.pkl

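Replace xx in the command with the actual epoch number to test that checkpoint. The test prints the COCO metrics and writes a result.pkl file, which is then used to compute Precision/Recall/F1.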
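If you want to peek inside result.pkl first: judging from the way tools/analysis_tools/confusion_matrix.py reads it (`per_img_res['pred_instances']` with `'bboxes'`, `'scores'` and `'labels'`), it is a list with one dict per test image. The sketch below counts, per predicted class, the detections whose score reaches a threshold. The helper name `count_detections` is hypothetical, and the toy data uses NumPy arrays in place of the CPU tensors the real file contains.

```python
from collections import Counter

import numpy as np


def count_detections(results, score_thr=0.5):
    """Count detections per predicted label with score >= score_thr.

    Assumes `results` is a list with one dict per image, each holding
    'pred_instances' with 'scores' and 'labels'. np.asarray also accepts
    CPU torch tensors, which is what the dumped file stores.
    """
    counts = Counter()
    for per_img_res in results:
        inst = per_img_res['pred_instances']
        scores = np.asarray(inst['scores'])
        labels = np.asarray(inst['labels'])
        for label in labels[scores >= score_thr]:
            counts[int(label)] += 1
    return counts


# Toy stand-in for the loaded result.pkl: two images
toy_results = [
    {'pred_instances': {'scores': np.array([0.9, 0.4]),
                        'labels': np.array([0, 0])}},
    {'pred_instances': {'scores': np.array([0.7]),
                        'labels': np.array([1])}},
]
print(count_detections(toy_results))

# With the real file (mmengine's loader, as used in confusion_matrix.py):
# from mmengine.fileio import load
# print(count_detections(load('result.pkl')))
```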

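Open tools/analysis_tools/confusion_matrix.py.

Inside main(), after the call to calculate_confusion_matrix(...), add the following code (np and os are already imported in that file):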


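    # Rows = ground truth, columns = predictions; the last row/column is
    # "background", so it is excluded from the per-class statistics.
    TP = np.diag(confusion_matrix)[:-1]
    FP = np.sum(confusion_matrix, axis=0)[:-1] - TP
    FN = np.sum(confusion_matrix, axis=1)[:-1] - TP

    with np.errstate(divide='ignore', invalid='ignore'):  # empty classes -> nan
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)
        f1 = 2 * (precision * recall) / (precision + recall)
    mean_precision = np.nanmean(precision)
    mean_recall = np.nanmean(recall)

    print('Mean Precision:', mean_precision)
    print('Mean Recall:', mean_recall)
    print('F1:', f1)
    print('Precision', precision)
    print('Recall', recall)

    # Append Precision / Recall / F1 of the first class (index 0) to PRF1.txt;
    # with a single-class dataset this is the only class.
    output_file_path = os.path.join(args.save_dir, 'PRF1.txt')
    with open(output_file_path, 'a') as output_file:
        output_file.write(f'{precision[0]:.5f}   {recall[0]:.5f}   {f1[0]:.5f}\n')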
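To make the snippet's arithmetic concrete, here is a self-contained sketch on a hand-made 2-class matrix laid out the way calculate_confusion_matrix() builds it (rows = ground truth, columns = predictions, last row/column = background). The function name `per_class_prf1` and the toy counts are illustrative. The background entry is dropped because its diagonal cell is never incremented, so keeping it would only add a 0 or NaN to any average.

```python
import numpy as np


def per_class_prf1(confusion_matrix):
    """Per-class Precision/Recall/F1 from an MMDetection-style confusion matrix.

    Assumed layout: shape (num_classes + 1, num_classes + 1), rows = ground
    truth, columns = predictions, last row/column = background.
    """
    cm = np.asarray(confusion_matrix, dtype=np.float64)
    tp = np.diag(cm)[:-1]
    fp = cm.sum(axis=0)[:-1] - tp  # includes background false positives (last row)
    fn = cm.sum(axis=1)[:-1] - tp  # includes missed objects (last column)
    with np.errstate(divide='ignore', invalid='ignore'):
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Two classes + background. Rows: GT class 0, GT class 1, background.
cm = [[8, 1, 1],   # class 0: 8 correct, 1 labelled as class 1, 1 missed
      [0, 5, 5],   # class 1: 5 correct, 5 missed
      [2, 0, 0]]   # 2 background regions detected as class 0
precision, recall, f1 = per_class_prf1(cm)
print('Precision', precision)
print('Recall', recall)
print('F1', f1)
```

Here class 0 gets P = 8/10, R = 8/10, and class 1 gets P = 5/6, R = 5/10, F1 = 0.625.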

运行命令

python tools/analysis_tools/confusion_matrix.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py result.pkl results/ --score-thr 0.5 

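This produces the metrics for that single epoch's checkpoint (printed to the console and appended to results/PRF1.txt).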

2. Metrics for the checkpoints of all epochs

Modify tools/test.py by replacing its entire contents with my code:

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from copy import deepcopy

from mmengine import ConfigDict
from mmengine.config import Config, DictAction
from mmengine.runner import Runner

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.evaluation import DumpDetResults
from mmdet.registry import RUNNERS
from mmdet.utils import setup_cache_size_limit_of_dynamo


# TODO: support fuse_conv_bn and format_only
def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--out',
        type=str,
        help='dump predictions to a pickle file for offline evaluation')
    parser.add_argument(
        '--show', action='store_true', help='show prediction results')
    parser.add_argument(
        '--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    parser.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--tta', action='store_true')
    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
    # will pass the `--local-rank` parameter to `tools/train.py` instead
    # of `--local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args

def get_checkpoint_files(directory):
    """Return the paths of all .pth files under `directory` (searched recursively)."""
    checkpoint_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.pth'):
                checkpoint_files.append(osp.join(root, file))
    return checkpoint_files

def main():
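    # Default parameters (hard-coded; edit them for your setup)
    config_path = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py'  # replace with your config file path
    checkpoint_directory = r'D:\mmdetection-main\work_dirs'  # replace with the directory holding your .pth files
    work_dir = r'D:\mmdetection-main\work_dirs'  # work dir for test logs; each .pkl is written next to its .pth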

    # Collect the .pth file paths
    checkpoint_files = get_checkpoint_files(checkpoint_directory)

    for checkpoint_file in checkpoint_files:
        # Set up the arguments (replaces the command-line parser)
        args = argparse.Namespace(
            config=config_path,
            checkpoint=checkpoint_file,
            work_dir=work_dir,
            out=None,
            show=False,
            show_dir=None,
            wait_time=2,
            cfg_options=None,
            launcher='none',
            tta=False,
            local_rank=0
        )

        # Output file: same name as the .pth, but with a .pkl extension
        out_file = osp.splitext(checkpoint_file)[0] + '.pkl'
        args.out = out_file

        # Reduce the number of repeated compilations and improve
        # testing speed.
        setup_cache_size_limit_of_dynamo()

        # load config
        cfg = Config.fromfile(args.config)
        cfg.launcher = args.launcher
        if args.cfg_options is not None:
            cfg.merge_from_dict(args.cfg_options)

        # work_dir is determined in this priority: CLI > segment in file > filename
        if args.work_dir is not None:
            # update configs according to CLI args if args.work_dir is not None
            cfg.work_dir = args.work_dir
        elif cfg.get('work_dir', None) is None:
            # use config filename as default work_dir if cfg.work_dir is None
            cfg.work_dir = osp.join('./work_dirs',
                                    osp.splitext(osp.basename(args.config))[0])

        # Load the weights of the current checkpoint
        cfg.load_from = args.checkpoint

        if args.show or args.show_dir:
            cfg = trigger_visualization_hook(cfg, args)
        if args.tta:

            if 'tta_model' not in cfg:
                warnings.warn('Cannot find ``tta_model`` in config, '
                              'we will set it as default.')
                cfg.tta_model = dict(
                    type='DetTTAModel',
                    tta_cfg=dict(
                        nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))
            if 'tta_pipeline' not in cfg:
                warnings.warn('Cannot find ``tta_pipeline`` in config, '
                              'we will set it as default.')
                test_data_cfg = cfg.test_dataloader.dataset
                while 'dataset' in test_data_cfg:
                    test_data_cfg = test_data_cfg['dataset']
                cfg.tta_pipeline = deepcopy(test_data_cfg.pipeline)
                flip_tta = dict(
                    type='TestTimeAug',
                    transforms=[
                        [
                            dict(type='RandomFlip', prob=1.),
                            dict(type='RandomFlip', prob=0.)
                        ],
                        [
                            dict(
                                type='PackDetInputs',
                                meta_keys=('img_id', 'img_path', 'ori_shape',
                                           'img_shape', 'scale_factor', 'flip',
                                           'flip_direction'))
                        ],
                    ])
                cfg.tta_pipeline[-1] = flip_tta
            cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
            cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline

        # build the runner from config
        if 'runner_type' not in cfg:
            # build the default runner
            runner = Runner.from_cfg(cfg)
        else:
            # build customized runner from the registry
            # if 'runner_type' is set in the cfg
            runner = RUNNERS.build(cfg)
        # add `DumpResults` dummy metric
        if args.out is not None:
            assert args.out.endswith(('.pkl', '.pickle')), \
                'The dump file must be a pkl file.'
            runner.test_evaluator.metrics.append(
                DumpDetResults(out_file_path=args.out))
        # start testing
        runner.test()

if __name__ == '__main__':
    main()

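Then simply run this file (python tools/test.py, no command-line arguments needed).

Note that I hard-coded some settings in main(); change them before use:

    # Default parameters
    config_path = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py'  # replace with your config file path
    checkpoint_directory = r'D:\mmdetection-main\work_dirs'  # replace with the directory holding your .pth files
    work_dir = r'D:\mmdetection-main\work_dirs'  # work dir for test logs; each .pkl is written next to its .pth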
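get_checkpoint_files() returns the .pth files in whatever order os.walk yields them (file-system dependent) and also picks up .pth files in subfolders. If you prefer to test the checkpoints in epoch order, a small helper like the hypothetical `sorted_epoch_checkpoints()` below could replace it. It assumes the epoch_<N>.pth naming from the top of this post and only looks at the top level of the directory.

```python
import glob
import os
import re
import tempfile


def sorted_epoch_checkpoints(directory):
    """Return epoch_<N>.pth paths in `directory`, sorted numerically by N."""
    pattern = re.compile(r'epoch_(\d+)\.pth$')
    found = []
    for path in glob.glob(os.path.join(directory, 'epoch_*.pth')):
        match = pattern.search(os.path.basename(path))
        if match:  # skips names such as epoch_best.pth
            found.append((int(match.group(1)), path))
    return [path for _, path in sorted(found)]


# Usage demo on a throwaway directory with empty dummy files
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ('epoch_10.pth', 'epoch_2.pth', 'epoch_1.pth', 'epoch_best.pth'):
        open(os.path.join(demo_dir, name), 'w').close()
    print([os.path.basename(p) for p in sorted_epoch_checkpoints(demo_dir)])
    # ['epoch_1.pth', 'epoch_2.pth', 'epoch_10.pth']
```

Plain string sorting would put epoch_10 before epoch_2, which is why the epoch number is parsed as an integer.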

Then modify tools/analysis_tools/confusion_matrix.py, again replacing its entire contents with my code:

import argparse
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from mmcv.ops import nms
from mmengine import Config, DictAction
from mmengine.fileio import load
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar

from mmdet.evaluation import bbox_overlaps
from mmdet.registry import DATASETS
from mmdet.utils import replace_cfg_vals, update_data_root


def parse_args():
    # No longer needed: main() takes its parameters directly (hard-coded below)
    pass

def calculate_confusion_matrix(dataset,
                               results,
                               score_thr=0,
                               nms_iou_thr=None,
                               tp_iou_thr=0.5):

    num_classes = len(dataset.metainfo['classes'])
    confusion_matrix = np.zeros(shape=[num_classes + 1, num_classes + 1])
    assert len(dataset) == len(results)
    prog_bar = ProgressBar(len(results))
    for idx, per_img_res in enumerate(results):
        res_bboxes = per_img_res['pred_instances']
        gts = dataset.get_data_info(idx)['instances']
        analyze_per_img_dets(confusion_matrix, gts, res_bboxes, score_thr,
                             tp_iou_thr, nms_iou_thr)
        prog_bar.update()
    return confusion_matrix


def analyze_per_img_dets(confusion_matrix,
                         gts,
                         result,
                         score_thr=0,
                         tp_iou_thr=0.5,
                         nms_iou_thr=None):

    true_positives = np.zeros(len(gts))
    gt_bboxes = []
    gt_labels = []
    for gt in gts:
        gt_bboxes.append(gt['bbox'])
        gt_labels.append(gt['bbox_label'])

    gt_bboxes = np.array(gt_bboxes)
    gt_labels = np.array(gt_labels)

    unique_label = np.unique(result['labels'].numpy())

    for det_label in unique_label:
        mask = (result['labels'] == det_label)
        det_bboxes = result['bboxes'][mask].numpy()
        det_scores = result['scores'][mask].numpy()

        if nms_iou_thr:
            det_bboxes, _ = nms(
                det_bboxes, det_scores, nms_iou_thr, score_threshold=score_thr)
        ious = bbox_overlaps(det_bboxes[:, :4], gt_bboxes)
        for i, score in enumerate(det_scores):
            det_match = 0
            if score >= score_thr:
                for j, gt_label in enumerate(gt_labels):
                    if ious[i, j] >= tp_iou_thr:
                        det_match += 1
                        if gt_label == det_label:
                            true_positives[j] += 1  # TP
                        confusion_matrix[gt_label, det_label] += 1
                if det_match == 0:  # BG FP
                    confusion_matrix[-1, det_label] += 1
    for num_tp, gt_label in zip(true_positives, gt_labels):
        if num_tp == 0:  # FN
            confusion_matrix[gt_label, -1] += 1


def plot_confusion_matrix(confusion_matrix,
                          labels,
                          save_dir=None,
                          show=True,
                          title='Normalized Confusion Matrix',
                          color_theme='plasma'):

    # normalize the confusion matrix
    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
    confusion_matrix = \
        confusion_matrix.astype(np.float32) / per_label_sums * 100

    num_classes = len(labels)
    fig, ax = plt.subplots(
        figsize=(0.5 * num_classes, 0.5 * num_classes * 0.8), dpi=180)
    cmap = plt.get_cmap(color_theme)
    im = ax.imshow(confusion_matrix, cmap=cmap)
    plt.colorbar(mappable=im, ax=ax)

    title_font = {'weight': 'bold', 'size': 12}
    ax.set_title(title, fontdict=title_font)
    label_font = {'size': 10}
    plt.ylabel('Ground Truth Label', fontdict=label_font)
    plt.xlabel('Prediction Label', fontdict=label_font)

    # draw locator
    xmajor_locator = MultipleLocator(1)
    xminor_locator = MultipleLocator(0.5)
    ax.xaxis.set_major_locator(xmajor_locator)
    ax.xaxis.set_minor_locator(xminor_locator)
    ymajor_locator = MultipleLocator(1)
    yminor_locator = MultipleLocator(0.5)
    ax.yaxis.set_major_locator(ymajor_locator)
    ax.yaxis.set_minor_locator(yminor_locator)

    # draw grid
    ax.grid(True, which='minor', linestyle='-')

    # draw label
    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    ax.tick_params(
        axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
    plt.setp(
        ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')

    # draw confution matrix value
    for i in range(num_classes):
        for j in range(num_classes):
            ax.text(
                j,
                i,
                '{}%'.format(
                    int(confusion_matrix[
                        i,
                        j]) if not np.isnan(confusion_matrix[i, j]) else -1),
                ha='center',
                va='center',
                color='w',
                size=7)

    ax.set_ylim(len(confusion_matrix) - 0.5, -0.5)  # matplotlib>3.1.1

    fig.tight_layout()
    if save_dir is not None:
        plt.savefig(
            os.path.join(save_dir, 'confusion_matrix.png'), format='png')
    if show:
        plt.show()

def main(config=None, prediction_path=None, save_dir=None, show=True, color_theme='plasma', score_thr=0.3, tp_iou_thr=0.5, nms_iou_thr=None, cfg_options=None):
    if config is None or prediction_path is None or save_dir is None:
        raise ValueError("config, prediction_path, and save_dir must be provided.")

    cfg = Config.fromfile(config)

    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)

    # update data root according to MMDET_DATASETS
    update_data_root(cfg)

    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmdet'))

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    dataset = DATASETS.build(cfg.test_dataloader.dataset)

    results = load(prediction_path)

    confusion_matrix = calculate_confusion_matrix(dataset, results, score_thr, nms_iou_thr, tp_iou_thr)

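    # Rows = ground truth, columns = predictions; the last row/column is
    # "background", so it is excluded from the per-class statistics.
    TP = np.diag(confusion_matrix)[:-1]
    FP = np.sum(confusion_matrix, axis=0)[:-1] - TP
    FN = np.sum(confusion_matrix, axis=1)[:-1] - TP

    with np.errstate(divide='ignore', invalid='ignore'):  # empty classes -> nan
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)
    mean_precision = np.nanmean(precision)
    mean_recall = np.nanmean(recall)
    # F1 of the first class (index 0); with a single-class dataset this is
    # the only class. Change the index to report another class.
    f1 = 2 * (precision[0] * recall[0]) / (precision[0] + recall[0])

    print('Mean Precision:', mean_precision)
    print('Mean Recall:', mean_recall)
    print('F1:', f1)
    print('Precision', precision[0])
    print('Recall', recall[0])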

    #print('TP:', TP)
    #print('FP:', FP)
    #print('FN', FN)

    output_file_path = os.path.join(save_dir, 'PRF1.txt')
    with open(output_file_path, 'a') as output_file:
        output_file.write(f'{prediction_path}    {precision[0]:.5f}   {recall[0]:.5f}   {f1:.5f}\n')

if __name__ == '__main__':
    config = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py'
    save_dir = r'D:\mmdetection-main\results'

    def numerical_sort(value):
        # Sort key: the epoch number N in a file name such as epoch_N.pkl
        filename = os.path.basename(value)
        parts = filename.split('epoch_')
        if len(parts) > 1:
            number_part = parts[1].split('.')[0]
            try:
                return int(number_part)
            except ValueError:
                return float('inf')
        else:
            return float('inf')

    # Collect all epoch_*.pkl files written by the modified test.py and sort them by epoch number
    prediction_files = sorted(glob.glob(r'D:\mmdetection-main\work_dirs\epoch_*.pkl'), key=numerical_sort)  # change to the folder holding your .pkl files
    print(config)
    print(prediction_files)
    print(save_dir)
    for prediction_path in prediction_files:
        main(config=config, prediction_path=prediction_path, save_dir=save_dir)

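Here,

    config = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py' # change to your config file path
    save_dir = r'D:\mmdetection-main\results' # change to the directory where the metrics file is saved

and the threshold parameters in the signature of main():

    def main(config=None, prediction_path=None, save_dir=None, show=True, color_theme='plasma', score_thr=0.3, tp_iou_thr=0.5, nms_iou_thr=None, cfg_options=None):

These values must be set by hand. Note that the default score_thr here is 0.3, whereas the single-epoch command in Part 1 used --score-thr 0.5; use the same value in both if you want comparable numbers. Also point the glob.glob(...) pattern in the `__main__` block at the directory where test.py wrote the epoch_*.pkl files.

This file is also run directly.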
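The core of calculate_confusion_matrix() is the matching rule in analyze_per_img_dets(): a detection with score >= score_thr adds 1 to cell [gt_label, det_label] for every ground-truth box it overlaps with IoU >= tp_iou_thr; a detection that overlaps no box counts as a background false positive (last row); a ground-truth box never hit by a same-class detection counts as missed (last column). The sketch below reproduces that rule in plain NumPy for one image so it can be tried without MMDetection. It is a simplified illustration, not the script itself: `iou()` stands in for mmdet's bbox_overlaps, NMS is omitted, and both function names are hypothetical.

```python
import numpy as np


def iou(boxes_a, boxes_b):
    """Pairwise IoU of (x1, y1, x2, y2) boxes; returns shape (len(a), len(b))."""
    a = np.asarray(boxes_a, dtype=np.float64).reshape(-1, 4)[:, None, :]
    b = np.asarray(boxes_b, dtype=np.float64).reshape(-1, 4)[None, :, :]
    iw = np.clip(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0, None)
    ih = np.clip(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0, None)
    inter = iw * ih
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    return inter / (area_a + area_b - inter)


def add_image_to_matrix(cm, gt_boxes, gt_labels, det_boxes, det_labels,
                        det_scores, score_thr=0.3, tp_iou_thr=0.5):
    """Update confusion matrix `cm` in place with one image's detections."""
    gt_labels = np.asarray(gt_labels, dtype=int)
    hit = np.zeros(len(gt_labels), dtype=bool)
    ious = iou(det_boxes, gt_boxes)
    for i, (label, score) in enumerate(zip(det_labels, det_scores)):
        if score < score_thr:
            continue
        matched = False
        for j, gt_label in enumerate(gt_labels):
            if ious[i, j] >= tp_iou_thr:
                matched = True
                if gt_label == label:
                    hit[j] = True  # this GT box was found by its own class
                cm[gt_label, label] += 1
        if not matched:
            cm[-1, label] += 1  # background false positive
    for gt_label in gt_labels[~hit]:
        cm[gt_label, -1] += 1  # missed ground truth (false negative)
    return cm


# One image, 2 classes + background
cm = np.zeros((3, 3))
add_image_to_matrix(cm,
                    gt_boxes=[[0, 0, 10, 10], [20, 20, 30, 30]], gt_labels=[0, 1],
                    det_boxes=[[1, 1, 10, 10], [50, 50, 60, 60]], det_labels=[0, 1],
                    det_scores=[0.9, 0.8])
print(cm)
```

The first detection overlaps GT box 0 with IoU 0.81 (a class-0 hit), the second overlaps nothing (background FP for class 1), and GT box 1 is never found (a class-1 miss).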

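3. Converting the output .txt metrics file to Excel (.xlsx) format

import re

import pandas as pd

# Read the text file (PRF1.txt is written into save_dir by the modified
# confusion_matrix.py; adjust the path if needed)
with open('PRF1.txt', 'r') as file:
    lines = file.readlines()

# Each line (Part 2 format): <.pkl path>   <Precision>   <Recall>   <F1>
data = []
for line in lines:
    line = line.strip()
    if line:
        # rsplit keeps a path that contains spaces in one piece
        path, precision, recall, f1 = line.rsplit(maxsplit=3)
        # Use the epoch number from e.g. ...\epoch_12.pkl; fall back to the path
        match = re.search(r'epoch_(\d+)\.pkl', path)
        epoch = int(match.group(1)) if match else path
        data.append([epoch, float(precision), float(recall), float(f1)])

# Create the DataFrame
df = pd.DataFrame(data,
                  columns=['epoch', 'Precision', 'Recall', 'F1'])

# Save as an Excel file (writing .xlsx requires an Excel writer such as openpyxl)
df.to_excel('PRF1.xlsx', index=False)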

More articles are on the way, aiming to be concise and accurate. Follow me and let's discuss together!
