用sklearn计算roc、acc、precision、recall、f1——首次用python的类——求取output的pre

仅作为记录,大佬请跳过。

sklearn求取指标的全部代码

from sklearn.metrics import roc_curve, auc, accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import numpy as np

class Score:
    def __init__(self, y_output, y_label, y_pre):
        self.y_output = y_output
        self.y_label = y_label
        self.y_pre = y_pre
    def cal_roc(self):
        cls = len(self.y_output[0])
        for i in range(cls):
            fpr, tpr, _ = roc_curve(self.y_label, [self.y_output[j][i] for j in range(len(self.y_output))], pos_label=i)
            roc_auc = auc(fpr, tpr)
            print('ok')
            return fpr, tpr, roc_auc        # 只进行一次roc
    def cal_acc(self):
        return accuracy_score(self.y_label, self.y_pre)
    def cal_precision(self):
        return precision_score(self.y_label, self.y_pre)
    def cal_recall(self):
        return recall_score(self.y_label, self.y_pre)
    def cal_f1(self):
        return f1_score(self.y_label, self.y_pre)

网络训练后对output的处理

在这里插入图片描述
全部代码

import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils
import torch.nn as nn

@torch.no_grad()
def final_test(data_loader, model, device):
    save_target=[]
    save_output=[]
    save_pre=[]
    m = nn.Softmax(dim=1)

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)

        # 保存待计算的数据
        save_output.extend(list((m(output)).cpu().numpy()))
        save_target.extend(list(target.cpu().numpy()))
        save_pre.extend(list(torch.max(m(output),dim=1).indices.cpu().numpy()))

    print('ok')
    return  save_output,save_target,save_pre

main函数中获得output的处理 和 使用计算roc的类

在这里插入图片描述

    # TODO: 加入test的指标计算
    if args.final_test:
        checkpoint_dir = os.path.join(args.output_dir, args.final_test_best_checkpoint)
        checkpoint = torch.load(checkpoint_dir)
        model.load_state_dict(checkpoint['model'])
        ft_output,ft_target,ft_pre = final_test(data_loader_test, model, device)
        score = Score(ft_output,ft_target,ft_pre)
        acc = score.cal_acc()
        precision = score.cal_precision()
        recall = score.cal_recall()
        f1 = score.cal_f1()
        fpr, tpr, roc_auc = score.cal_roc()
        model_name = os.path.basename(checkpoint_dir)

        csv_path = os.path.join(args.output_dir, 'score.csv')
        if not os.path.exists(csv_path):
            with open(csv_path, 'w', encoding='utf-8') as f:
                writer = csv.writer(f, lineterminator='\n')
                writer.writerow(['model', 'acc', 'precision', 'recall', 'f1', 'roc'])  # 【】
                writer.writerow([model_name, acc, precision, recall, f1, roc_auc])  # 【】
                print('文件不存在')

        df = pd.read_csv(csv_path, encoding="utf-8")
        if model_name not in df.values:
            with open(csv_path, 'a+', encoding='utf-8') as f:
                writer = csv.writer(f, lineterminator='\n')
                writer.writerow([model_name, acc, precision, recall, f1, roc_auc])  # 【】

        return

文件创建展示

main.py中获得output的处理 和 使用计算roc的类,同时

from calculate_roc_f1 import Score

sklearn求取指标的全部代码在calculate_roc_f1.py

网络训练后对output的处理在engine.py

在这里插入图片描述

参考

感谢大佬博主文章:python的sklearn计算roc、acc、precision、recall、f1

在这里插入图片描述

python的topk

求取output种预测的类别

_,pre=outputs.topk(1, dim=1, largest=True)

参考

传送门

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值