Pytorch训练模型得到输出后计算F1-Score 和AUC

1、计算F1-Score

对于二分类来说,假设batch size 大小为64的话,那么模型一个batch的输出应该是torch.size([64,2]),所以首先做的是得到这个二维矩阵的每一行的最大索引值,然后添加到一个列表中,同时把标签也添加到一个列表中,最后使用sklearn中计算F1的工具包进行计算,代码如下

import numpy as np
import sklearn.metrics import f1_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的预测输出
    prob = prob.cpu().numpy() #先把prob转到CPU上,然后再转成numpy,如果本身在CPU上训练的话就不用先转成CPU了
    prob_all.extend(np.argmax(prob,axis=1)) #求每一行的最大值索引
    label_all.extend(label)
print("F1-Score:{:.4f}".format(f1_score(label_all,prob_all)))

2、计算AUC

计算AUC的时候,本次使用的是sklearn中的roc_auc_score () 方法

输入参数:

y_true:真实的标签。形状 (n_samples,) 或 (n_samples, n_classes)。二分类的形状 (n_samples,1),而多标签情况的形状 (n_samples, n_classes)。

y_score:目标分数。形状 (n_samples,) 或 (n_samples, n_classes)。二分类情况形状 (n_samples,1),“分数必须是具有较大标签的类的分数”,通俗点理解:模型打分的第二列。举个例子:模型输出的得分是一个数组 [0.98361117 0.01638886],索引是其类别,这里 “较大标签类的分数”,指的是索引为 1 的分数:0.01638886,也就是正例的预测得分。

average='macro':二分类时,该参数可以忽略。用于多分类,' micro ':将标签指标矩阵的每个元素看作一个标签,计算全局的指标。' macro ':计算每个标签的指标,并找到它们的未加权平均值。这并没有考虑标签的不平衡。' weighted ':计算每个标签的指标,并找到它们的平均值,根据支持度 (每个标签的真实实例的数量) 进行加权。

sample_weight=None:样本权重。形状 (n_samples,),默认 = 无。

max_fpr=None:

multi_class='raise':(多分类的问题在下一篇文章中解释)

labels=None:

输出:

auc:是一个 float 的值。
 

import numpy as np
import sklearn.metrics import roc_auc_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的预测输出
    prob_all.extend(prob[:,1].cpu().numpy()) #prob[:,1]返回每一行第二列的数,根据该函数的参数可知,y_score表示的较大标签类的分数,因此就是最大索引对应的那个值,而不是最大索引值
    label_all.extend(label)
print("AUC:{:.4f}".format(roc_auc_score(label_all,prob_all)))

参考:https://blog.csdn.net/pearl8899/article/details/109829306

https://blog.csdn.net/iamjingong/article/details/86636934

  • 8
    点赞
  • 60
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
### 回答1: 以下是使用PyTorch计算F1-score的代码示例: ```python import torch def f1_score(y_true, y_pred): """ 计算F1-score :param y_true: 真实标签 :param y_pred: 预测标签 :return: F1-score """ y_true = torch.Tensor(y_true) y_pred = torch.Tensor(y_pred) tp = torch.sum(y_true * y_pred) fp = torch.sum((1 - y_true) * y_pred) fn = torch.sum(y_true * (1 - y_pred)) precision = tp / (tp + fp + 1e-10) recall = tp / (tp + fn + 1e-10) f1 = 2 * precision * recall / (precision + recall + 1e-10) return f1.item() ``` 其中,y_true和y_pred分别是真实标签和预测标签,可以是任意形状的张量。在函数中,我们首先将它们转换为PyTorch张量。然后,我们计算真正例(tp)、假正例(fp)和假负例(fn)的数量,以便计算精度(precision)、召回率(recall)和F1-score。最后,我们使用.item()方法将F1-score从张量中提取出来并返回。 ### 回答2: F1-score是评估分类模型效果的指标之一,它综合了模型的精确度和召回率。在PyTorch中,计算F1-score需要从以下两个方面进行考虑:模型的预测和真实标签的比对,以及对预测结果的阈值的调整。 首先,模型的预测和真实标签的比对是F1-score计算的基础。这可以通过使用PyTorch的torch.argmax函数来实现。假设模型的输出是一个张量,包含每个样本的概率分布。可以使用torch.argmax函数获取每个样本的预测类别,然后和真实标签进行比较,得到分类的预测结果,在此基础上进一步计算模型的精确度和召回率。 接下来,对预测结果的阈值的调整是提高F1-score的关键,因为不同的阈值会对模型的表现产生不同的影响。一般来说,当模型对某个类别的预测概率大于某个阈值时,就会将该样本归为该类别。这个阈值可以通过调整模型的决策边界或者设置一个额外的参数来实现。在PyTorch中,可以通过自定义一个损失函数来实现F1-score计算,该损失函数可以设置一个阈值参数,并根据该参数进行阈值的调整,使得F1-score最大化。 总之,计算F1-score需要从模型的预测和真实标签的比对以及对预测结果的阈值的调整两个方面进行考虑。在PyTorch中,可以通过使用torch.argmax函数获取每个样本的预测类别,并自定义一个损失函数以实现对预测结果的阈值的调整,从而计算出模型的F1-score指标。 ### 回答3: F1 分数是评估二分类模型性能的一种常用指标。它同时考虑了模型的查准率 (Precision) 和查全率 (Recall),具有很好的平衡性,因此可以更全面地反映模型的性能。在 PyTorch 中,我们可以通过以下代码计算 F1 分数: ```python from sklearn.metrics import f1_score # 预测值和真实值分别存储在 y_pred 和 y_true 中 y_pred = model(x_test) y_pred = (y_pred > 0.5).float() # 将概率值转换成二分类标签 y_pred = y_pred.cpu().numpy().squeeze() y_true = y_test.cpu().numpy().squeeze() # 计算 F1 分数 f1 = f1_score(y_true, y_pred) ``` 在上述代码中,首先需要导入 `sklearn.metrics` 库中的 f1_score 函数。接下来,我们需要将模型的预测值和真实值分别存储在 y_pred 和 y_true 中。在二分类任务中,我们通常需要将模型的输出概率值转换成二分类标签。因此,我们使用 `(y_pred > 0.5).float()` 将概率值大于 0.5 的标记为 1,否则标记为 0。最后,我们将 y_pred 和 y_true 转换成 numpy 数组,并使用 f1_score 函数计算 F1 分数。 需要注意的是,不同的 F1 分数实现方式可能会存在一些差异,因此在使用时需要根据具体应用场景进行选择。在实际应用中,我们还需要注意评估指标的合理性和模型的泛化能力,避免模型过拟合或欠拟合等问题,提高模型的稳定性和可靠性。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值