import os
import json
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
import creatdataset
import resnext
class ConfusionMatrix(object):
"""
注意,如果显示的图像不全,是matplotlib版本问题
本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
需要额外安装prettytable库,能够将输出打印成列表的形式
"""
def __init__(self, num_classes: int, labels: list):
self.matrix = np.zeros((num_classes, num_classes))
# 创造一个shape为num_classes*num_classes的正方形混淆矩阵,且初始化为0
self.num_classes = num_classes
self.labels = labels
def update(self, preds, labels):
# 将预测的值和输入标签输入进来
for p, t in zip(preds, labels):
# p代表预测值,t代表真实标签
self.matrix[p, t] += 1
def summary(self):
# 精度/准确率
sum_TP = 0
for i in range(self
混淆矩阵的代码实现
于 2022-03-24 22:47:15 首次发布