绘制混淆矩阵时候的normalize的作用

normalize的使用有何影响?

normalize 参数用于控制混淆矩阵是否进行归一化。混淆矩阵是用于评估分类模型性能的表格,其中行表示实际类别,列表示预测类别。在可视化混淆矩阵时,有时候希望将每行的值归一化,以便更清晰地了解模型在每个类别上的性能,而不受类别样本数量的影响。

具体来说:

  • normalize=False 时,混淆矩阵中的值表示每个类别的样本数量,而不进行归一化。

  • normalize=True 时,混淆矩阵中的值被归一化为每个类别的样本百分比,即每行的和变为 1。这样可以更容易地比较不同类别之间的性能,而不受类别样本数量的差异影响。

在可视化混淆矩阵时,通常使用归一化的混淆矩阵,因为这样更容易识别模型在各个类别上的分类准确度。通过比较每个类别的归一化值,您可以更好地了解模型在不同类别上的性能表现。

学习测试代码

# 绘制混淆矩阵的图片
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='混淆矩阵',
                          cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    # print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=310, size=12)
    plt.yticks(tick_marks, classes, size=12)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white"
                 if cm[i, j] > thresh else "black")

        plt.tight_layout()
        plt.ylabel('真实标签', size=12)
        plt.xlabel('预测标签', size=12)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
以下是绘制混淆矩阵的PyTorch代码示例: ```python import matplotlib.pyplot as plt from sklearn.metrics import confusion_matrix def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. """ if not title: if normalize: title = 'Normalized confusion matrix' else: title = 'Confusion matrix, without normalization' # Compute confusion matrix cm = confusion_matrix(y_true, y_pred) # Only use the labels that appear in the data classes = classes[unique_labels(y_true, y_pred)] if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] fig, ax = plt.subplots() im = ax.imshow(cm, interpolation='nearest', cmap=cmap) ax.figure.colorbar(im, ax=ax) # We want to show all ticks... ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]), # ... and label them with the respective list entries xticklabels=classes, yticklabels=classes, title=title, ylabel='True label', xlabel='Predicted label') # Rotate the tick labels and set their alignment. plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Loop over data dimensions and create text annotations. fmt = '.2f' if normalize else 'd' thresh = cm.max() / 2. for i in range(cm.shape[0]): for j in range(cm.shape[1]): ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black") fig.tight_layout() return ax ``` 这段代码使用Matplotlib库和Scikit-learn库中的`confusion_matrix`函数来绘制混淆矩阵。在函数中,你需要提供真实标签`y_true`和预测标签`y_pred`,以及类别列表`classes`,它包含了你的模型预测的所有类别。你可以通过设置`normalize=True`来获得归一化的混淆矩阵。最后,使用`plt.show()`函数来显示混淆矩阵

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王摇摆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值