python绘制混淆矩阵

test_generator = test_datagen.flow_from_directory(
        'dataset/test',
        target_size=(48, 48),
        shuffle = False ,#随机打乱默认为true
        # batch_size=16,
        color_mode="grayscale",
        class_mode = 'categorical')
predictions = model.predict_generator(test_generator)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = test_generator.classes
labels = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
confusion_matrix = confusion_matrix(true_classes, predicted_classes)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator


def plotCM(classes, matrix, savname):
    """classes: a list of class names"""
    # Normalize by row
    matrix = matrix.astype(np.float)
    # linesum = matrix.sum(1)
    # linesum = np.dot(linesum.reshape(-1, 1), np.ones((1, matrix.shape[1])))
    # matrix /= linesum
    # plot
    plt.switch_backend('agg')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            ax.text(j, i, str('%.2f' % (matrix[i, j])), va='center', ha='center')
    ax.set_xticklabels([''] + classes, rotation=90)
    ax.set_yticklabels([''] + classes)
    # save
    plt.savefig(savname)
plotCM(labels, confusion_matrix, 'matrix.jpg')

在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值