一、混淆矩阵的定义
混淆矩阵也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的矩阵形式来表示。具体评价指标有总体精度、制图精度、用户精度等,这些精度指标从不同的侧面反映了图像分类的精度。
在人工智能中,混淆矩阵(confusion matrix)是可视化工具,特别用于监督学习,在无监督学习中一般叫做匹配矩阵。
在图像精度评价中,主要用于比较分类结果和实际测得值,可以把分类结果的精度显示在一个混淆矩阵里面。混淆矩阵是通过将每个实测像元的位置和分类与分类图像中的相应位置和分类相比较计算的。
二、混淆矩阵表达的含义
1.混淆矩阵的每一列代表了预测类别,每一列的总数表示预测为该类别的数据的数目;
2.每一行代表了数据的真实归属类别,每一行的数据总数表示该类别的数据实例的数目;每一列中的数值表示真实数据被预测为该类的数目。
三、混淆矩阵的表现方式
混淆矩阵 | 真实值 | ||
Positive | Negative | ||
预测值 | Positive | TP | FP |
Negative | FN | TN |
- 真实值=Positive,预测值=Positive (TP=True Positive)
(真阳性)样本的真实类别是正类,并且模型识别的结果也是正类。
- 真实值=Positive,预测值=Negative (FN=False Negative)
(假阳性)样本的真实类别是正类,但是模型将其识别为负类。
- 真实值=Negative,预测值=Positive (FP=True Positive)
(假阴性)样本的真实类别是负类,但是模型将其识别为正类。
- 真实值=Negative,预测值=Negative (TN=True Positive)
(真阴性)样本的真实类别是负类,并且模型将其识别为负类。
四、混淆矩阵的指标
(1)基本的统计结果
预测性分类模型,目的是预测结果的准确率越高越好。对应到混淆矩阵中,表现得方式就是TP(真阳性)与TN(真阴性)的数量大,而FP(假阳性)与FN(假阴性)的数量小。
(2)二级指标
在混淆矩阵里面统计的是个数,在大量数据前,TP、FN、TP、TN的个数不是能很好的体现出模型预测结果的好坏,从而无法对比模型的优劣程度。因此,混淆矩阵在基本结果上延伸出以下4个指标,又被称为二级指标。它们计算公式如下:
主要代码
import os
import math
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
class ConfusionMatrix(object):
def __init__(self, num_classes: int, labels: list):
self.matrix = np.zeros((num_classes, num_classes))
self.num_classes = num_classes
self.labels = labels
def update(self, preds, labels):
for p, t in zip(preds, labels):
self.matrix[p, t] += 1
def summary(self):
# calculate accuracy
sum_TP = 0
for i in range(self.num_classes):
sum_TP += self.matrix[i, i]
acc = sum_TP / np.sum(self.matrix)
print("the model accuracy is ", acc)
# precision, recall, specificity
table = PrettyTable()
table.field_names = ["", "Precision", "Recall", "Specificity"]
for i in range(self.num_classes):
TP = self.matrix[i, i]
FP = np.sum(self.matrix[i, :]) - TP
FN = np.sum(self.matrix[:, i]) - TP
TN = np.sum(self.matrix) - TP - FP - FN
Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
# 小数点取后三位
table.add_row([self.labels[i], Precision, Recall, Specificity])
print(table)
def plot(self):
# 绘制混淆矩阵
matrix = self.matrix
print(matrix)
plt.imshow(matrix, cmap=plt.cm.Blues)
# 设置x轴坐标label
plt.xticks(range(self.num_classes), self.labels, rotation=45)
# 设置y轴坐标label
plt.yticks(range(self.num_classes), self.labels)
# 右侧显示色谱
plt.colorbar()
plt.xlabel('True Labels')
plt.ylabel('Predicted Labels')
plt.title('Confusion matrix')
# 在图中标注数量/概率信息
thresh = matrix.max() / 2
for x in range(self.num_classes):
for y in range(self.num_classes):
# 注意这里的matrix[y, x]不是matrix[x, y]
info = int(matrix[y, x])
plt.text(x, y, info,
verticalalignment='center',
horizontalalignment='center',
color="white" if info > thresh else "black")
plt.tight_layout()
plt.show()