sklearn.metrics计算模型metric

 下载包

python -m pip install scikit-learn  -i https://pypi.tuna.tsinghua.edu.cn/simple

 导入包

from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score

 模型评估

def evaluate(model, data_loader, device):
    model.eval()

    # 验证样本总个数
    total_num = len(data_loader.dataset)

    # 用于存储预测正确的样本个数
    sum_num = torch.zeros(1).to(device)
    # 初始化存储预测结果和真实标签的列表
    all_labels = []
    all_preds = []
    all_probs = []

    cont = 0
    outPre = []
    outLabel = []

    loss_function =torch.nn.CrossEntropyLoss().cuda()
    # loss_function.to(device)
    mean_loss = torch.zeros(1).to(device)

    data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))

        loss = loss_function(pred, labels.to(device))
        mean_loss = (mean_loss * step + loss) / (step + 1)  # update mean losses

        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

        all_labels.extend(labels.cpu().numpy())
        preds = pred.cpu().numpy()
        all_preds.extend(preds)


        if cont == 0:
            outPre = pred.data.cpu()
            outLabel = labels.data.cpu()
        else:
            outPre = torch.cat((outPre, pred.data.cpu()), 0)
            outLabel = torch.cat((outLabel, labels.data.cpu()), 0)
        cont += 1

    y_pred = all_preds
    y_true = all_labels

    report = classification_report(y_true, y_pred, target_names=['Class 0', 'Class 1', 'Class 2'])
    print(report)

    precision = precision_score(y_true, y_pred, average=None)
    recall = recall_score(y_true, y_pred, average=None)
    f1 = f1_score(y_true, y_pred, average=None)

    print("Precision:\t\t", precision)
    print("Recall:\t\t", recall)
    print("F1 Score:\t\t", f1)

    # 计算微平均和宏平均的precision, recall, f1-score
    precision_micro = precision_score(y_true, y_pred, average='micro')
    recall_micro = recall_score(y_true, y_pred, average='micro')
    f1_micro = f1_score(y_true, y_pred, average='micro')

    precision_macro = precision_score(y_true, y_pred, average='macro')
    recall_macro = recall_score(y_true, y_pred, average='macro')
    f1_macro = f1_score(y_true, y_pred, average='macro')

    print("Micro-averaged Precision       Recall              \tF1 Score")
    print(f'{precision_micro}             {recall_micro}      \t{f1_micro}')
    print("Macro-averaged Precision\t\t\tRecall\t\t\tF1 Score")
    print(f'{precision_macro}\t\t\t{recall_macro}\t\t\t{f1_macro}')

    print("outPre:", outPre)
    print("outLabel", outLabel)
    print(f'cont = {cont},total_num = {total_num}')
    acc = sum_num.item() / total_num
    print('Loss:  {:.10f} Acc: {:.4f}'.format(mean_loss.item(),acc)) 


    return acc

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值