图像分类中混淆矩阵的绘制代码

import os
import json

import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator, FixedFormatter
from prettytable import PrettyTable

from model import MobileNetV2
from resnet import Mymodel


class ConfusionMatrix(object):
    """
    注意,如果显示的图像不全,是matplotlib版本问题
    本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
    需要额外安装prettytable库
    """
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def summary(self):
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("the model accuracy is ", acc)

        # precision, recall, specificity
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN
            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
            table.add_row([self.labels[i], Precision, Recall, Specificity])
        print(table)

    def plot(self):
        matrix = self.matrix
        print(matrix)
        # 对每个类别的预测数量应用softmax函数
        row_sums = matrix.sum(axis=0, keepdims=True)
        probabilities = matrix / row_sums
        plt.figure(figsize=(20, 14))
        plt.imshow(probabilities, cmap=plt.cm.Blues)
        # plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体
        # 指定中文字体为仿宋
        # plt.rcParams['font.family'] = 'FangSong'

        # 设置x轴坐标label
        plt.xticks(range(self.num_classes), range(self.num_classes), fontdict={'fontname': 'Times New Roman', 'fontsize':28})
        # 自定义x轴标签的绘制方式
        # ax = plt.gca()
        # ax.set_xticks(range(self.num_classes))
        # ax.set_xticklabels(self.labels, fontdict={'fontname': 'Times New Roman', 'fontsize': 10})

        # # 调整x轴标签的位置和对齐方式
        # plt.xticks(rotation=45, ha='right', fontsize=10, horizontalalignment='right')
        
        # 设置y轴坐标label
        plt.yticks(range(self.num_classes), self.labels,fontdict={'fontname': 'Times New Roman', 'fontsize': 28})
        # 显示colorbar
        # plt.colorbar()
        # plt.xlabel('真实标签',fontdict={'fontsize': 30})
        # plt.ylabel('预测标签',fontdict={'fontsize': 30})
        plt.xlabel('True Label',fontdict={'fontname': 'Times New Roman','fontsize': 30})
        plt.ylabel('Predicted Label',fontdict={'fontname': 'Times New Roman','fontsize': 30})
        # plt.title('Confusion matrix',fontdict={'fontname': 'Times New Roman', 'fontsize': 15})

        # 在图中标注数量/概率信息
        # thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                # info = int(matrix[y, x])
                prob = probabilities[y,x]
                if prob >= 0.01:
                    plt.text(x, y, f'{prob:.2f}',
                            verticalalignment='center',
                            horizontalalignment='center',
                            fontsize = 14,
                            color="white" if prob > 0.5 else "black")
                    
        # 调整图像边距
        # plt.subplots_adjust(bottom=0.2, left=0.2)
        plt.tight_layout()

        # 控制输出图像的分辨率并保存到指定路径
        plt.savefig('C:/Users/86133/Desktop/IRCHKD混淆矩阵/AID_50_4.png',dpi=600)

        plt.show()


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    data_transform = transforms.Compose([transforms.Resize((448,448)),
                                        #  transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "D:/Users/Code/FRSKD-main/DataSet/AID")  # flower data set path
    assert os.path.exists(image_path), "data path {} does not exist.".format(image_path)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform)

    batch_size = 8
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)
    # net = MobileNetV2(num_classes=5)
    net = Mymodel(num_classes=30)
    # load pretrain weights
    model_weight_path = "D:/Users/Code/WZ/pytorch_classification/ConfusionMatrix/AID_xiaorong_kd_97.84.pth"
    assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)

    # read class_indict
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
    json_file = open(json_label_path, 'r')
    class_indict = json.load(json_file)

    labels = [label for _, label in class_indict.items()]
    print('labels:',labels)
    confusion = ConfusionMatrix(num_classes=30, labels=labels)
    net.eval()
    with torch.no_grad():
        for val_data in tqdm(validate_loader):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device),is_test=True)
            # print('outputs.shape:',outputs.shape)
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            
            confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()
    confusion.summary()

食用方法:将代码中的Mymodel换成自己的模型,model_weight_path换成自己训练好的权重的路径,还要创建一个包含类别标签的json文件

# read class_indict
json_label_path = './class_indices.json'
assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
json_file = open(json_label_path, 'r')
class_indict = json.load(json_file)

labels = [label for _, label in class_indict.items()]

模型的类别数和ConfusionMatrix的类别数要对应。做好这些就可以一键运行了。之后可以根据自己类别数的多少调整下字体和画布的大小。

  • 11
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值