如果不想看前面的介绍可以直接点击目录转到“代码实现”与“参数讲解”部分
混淆矩阵简介
混淆矩阵(confusion matrix)是一种常用的分类模型评估指标,经可视化的混淆矩阵可以帮助我们直观地了解到模型对各类样本地分类情况。
对于一个分类模型来说,每一个样本均有两种属性,真实标签与预测标签,记为
y
r
e
a
l
,
y
p
r
e
d
i
c
t
y_{real}, y_{predict}
yreal,ypredict。由此出发我们可以定义混淆矩阵:
对于
n
n
n分类模型,混淆矩阵为一
n
×
n
n\times n
n×n阶方阵:
M
c
=
[
m
i
,
j
]
n
×
n
M_c=[m_{i,j}]_{n\times n}
Mc=[mi,j]n×n
其第
i
i
i行第
j
j
j列元素定义为:
m
i
,
j
=
n
y
r
e
a
l
=
i
,
y
p
r
e
d
i
c
t
=
j
n
y
r
e
a
l
=
i
m_{i, j}=\frac{n_{y_{real=i}, y_{predict=j}}} {n_{y_{real=i}}}
mi,j=nyreal=inyreal=i,ypredict=j
即混淆矩阵的第
i
i
i行第
j
j
j列为真实标签为
i
i
i的样本被预测为
j
j
j类的比例。
基于python matplotlib/sklearn库的混淆矩阵代码实现
由于并没有找到现成的绘制混淆矩阵的函数,因此基于matplotlib
与sklearn
库的相关函数自己编写实现混淆矩阵的绘制
效果示例
利用39节点电网仿真数据进行,数据共包含5000个样本,每个样本分为不稳定、稳定、潮流不收敛三类,在数据集中分别用0, 1, 2表示。选取支持向量机作为分类模型。
利用4000个样本组成训练集训练模型,再利用1000个样本组成的测试机对模型性能进行测试,并利用plot_matrix(y_true, y_pred, labels_name)
函数将结果绘制为混淆矩阵。
示例代码如下:
# 读取数据并划分训练集与测试集
test_features, test_labels = read_data(ADDRESS)
feature_train, feature_test, label_train, label_test = train_test_split(test_features, test_labels,
test_size=0.2, random_state=0)
# 利用训练集训练支持向量机模型
svc = SVC(kernel='rbf')
svc.fit(feature_train, label_train)
# 利用训练好的模型对测试集进行分类
label_test_svc = svc.predict(feature_test)
# 根据真实标签与预测标签绘制混淆矩阵
plot_matrix(label_test, label_test_svc, [0, 1, 2], title='confusion_matrix_svc',
axis_labels=['unstable', 'stable', 'non-convergence'])
最终结果如下所示:
代码实现
plot_matrix(y_true, y_pred, labels_name)
函数如下所示:
import matplotlib.pyplot as pl
from sklearn import metrics
# 相关库
def plot_matrix(y_true, y_pred, labels_name, title=None, thresh=0.8, axis_labels=None):
# 利用sklearn中的函数生成混淆矩阵并归一化
cm = metrics.confusion_matrix(y_true, y_pred, labels=labels_name, sample_weight=None) # 生成混淆矩阵
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] # 归一化
# 画图,如果希望改变颜色风格,可以改变此部分的cmap=pl.get_cmap('Blues')处
pl.imshow(cm, interpolation='nearest', cmap=pl.get_cmap('Blues'))
pl.colorbar() # 绘制图例
# 图像标题
if title is not None:
pl.title(title)
# 绘制坐标
num_local = np.array(range(len(labels_name)))
if axis_labels is None:
axis_labels = labels_name
pl.xticks(num_local, axis_labels, rotation=45) # 将标签印在x轴坐标上, 并倾斜45度
pl.yticks(num_local, axis_labels) # 将标签印在y轴坐标上
pl.ylabel('True label')
pl.xlabel('Predicted label')
# 将百分比打印在相应的格子内,大于thresh的用白字,小于的用黑字
for i in range(np.shape(cm)[0]):
for j in range(np.shape(cm)[1]):
if int(cm[i][j] * 100 + 0.5) > 0:
pl.text(j, i, format(int(cm[i][j] * 100 + 0.5), 'd') + '%',
ha="center", va="center",
color="white" if cm[i][j] > thresh else "black") # 如果要更改颜色风格,需要同时更改此行
# 显示
pl.show()
参数讲解
plot_matrix
函数参数包括:
y_true
样本的真实标签,为一向量
y_pred
样本的预测标签,为一向量,与真实标签长度相等
labels_name
样本在数据集中的标签名,如在示例中,样本的标签用0, 1, 2表示,则此处应为[0, 1, 2]
title=None
图片的标题
thresh=0.8
临界值,大于此值则图片上相应位置百分比为白色
axis_labels=None
最终图片中显示的标签名,如在示例中,样本标签用0, 1, 2表示分别表示失稳、稳定与潮流不收敛,我们最终图片中显示后者而非前者,则可令此参数为[‘unstable’, ‘stable’, ‘non-convergence’]