混淆矩阵的代码实现

import os
import json

import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable

import creatdataset
import resnext

class ConfusionMatrix(object):
    """
    注意,如果显示的图像不全,是matplotlib版本问题
    本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
    需要额外安装prettytable库,能够将输出打印成列表的形式
    """
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        # 创造一个shape为num_classes*num_classes的正方形混淆矩阵,且初始化为0
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        # 将预测的值和输入标签输入进来
        for p, t in zip(preds, labels):
           # p代表预测值,t代表真实标签
            self.matrix[p, t] += 1

    def summary(self):
        # 精度/准确率
        sum_TP = 0
        for i in range(self
  • 3
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值