Python 混淆矩阵 可视化(热度图)

代码如下:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号


def plotCM(classes, matrix, savname):
    """classes: a list of class names"""

    # Normalize by row
    matrix = matrix.astype(np.float)
    linesum = matrix.sum(1)
    linesum = np.dot(linesum.reshape(-1, 1), np.ones((1, matrix.shape[1])))
    matrix /= linesum

    # plot
    plt.switch_backend('agg')
    fig = plt.figure()

    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)

    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))

    for i in range(matrix.shape[0]):
        ax.text(i, i, str('%.2f' % (matrix[i, i] * 100)), va='center', ha='center')

    ax.set_xticklabels([''] + classes, rotation=90)
    ax.set_yticklabels([''] + classes)

    # save
    plt.savefig(savname)


if __name__ == '__main__':
    classes = ["A", "B", "C", "D", "E", "F", "G", "H"]

    matrix = np.array([[23, 1, 2, 52, 5, 0, 1, 0],
                       [0, 107, 2, 13, 18, 0, 2, 0],
                       [4, 23, 15, 6, 3, 0, 1, 0],
                       [12, 73, 1, 114, 0, 0, 0, 0],
                       [1, 0, 0, 0, 100, 0, 10, 0],
                       [0, 5, 0, 4, 7, 0, 0, 0],
                       [7, 10, 2, 31, 0, 0, 150, 0],
                       [0, 0, 0, 0, 0, 0, 5, 0]]
                     )
    savename = "test.png"

    plotCM(classes, matrix, savename)

结果图如下:

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值