【深度学习】多分类任务评估指标sklearn和torchmetrics对比
说明
sklearn 和 torchmetrics 两套 metric 代码在同一模型输出上计算得到的结果一致，下面对比它们的区别。评估指标代码写在下面：
sklearn代码
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
class MultiClassReport():
"""
Accuracy, F1 Score, Precision and Recall for multi - class classification task.
"""
def __init__(self, name='MultiClassReport', average='macro'):
    """Initialise the report and clear its accumulated state.

    Args:
        name: Identifier for this metric, stored as ``self._name``.
        average: Averaging strategy, stored as ``self.average``.
            Presumably forwarded as the ``average=`` argument of
            sklearn's ``f1_score`` / ``precision_score`` /
            ``recall_score`` (that use is outside this view — confirm).
    """
    # Zero-argument super() is the Python 3 idiom; unlike
    # super(MultiClassReport, self) it keeps working if the class
    # name is rebound or the class is renamed.
    super().__init__()
    self.average = average
    self._name = name
    # Start with empty probability/label accumulators.
    self.reset()
def reset(self):
    """Discard every probability and label collected so far.

    Rebinds both accumulators, ``y_prob`` and ``y_true``, to fresh
    empty lists.
    """
    self.y_prob, self.y_true = [], []
def update(self, probs, labels):
# 将Tensor转换为numpy数组并添加到相应列表中
if isinstance(probs, torch.Tensor):
if probs.requires_grad:
probs = probs.detach()
probs = probs.cpu().numpy()
if isinstance(labels, torch.Tensor):
if labels.requires_grad:
labels = labels.detach()
labels = labels.cpu()