使用sklearn绘制混淆矩阵(confusion matrix)
最近因为项目需要绘制混淆矩阵,在网上找了很多资源,发现绘制出的混淆矩阵问题很多,有的很丑,有的显示不全。
为此我将最后可用的资源整合在这里方便自己查阅,也方便大家使用。
完整的代码详见Github。
如果帮助到你的话,请给我的Github项目一颗Star。
需要引用的库
from sklearn.metrics import confusion_matrix # 生成混淆矩阵函数
import matplotlib.pyplot as plt # 绘图库
import numpy as np
数据预处理
# Names of the classes.
target_names = list('ABCDEFGH')
# Confusion matrix; true_label and prediction_label are both lists.
cm = confusion_matrix(y_true=true_label, y_pred=prediction_label)
绘制混淆矩阵的函数
def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True,
                          save_path='result.jpg'):
    """
    Draw a confusion matrix as an annotated heat map, save it and show it.

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix
                  (any 2-D array-like of counts is accepted)

    target_names: the class names used as tick labels,
                  for example: ['high', 'medium', 'low'];
                  pass None to keep the default numeric ticks

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed, from
                  matplotlib.pyplot.cm, e.g. plt.get_cmap('jet') or
                  plt.cm.Blues; None selects 'Blues'
                  see:
                  http://matplotlib.org/examples/color/colormaps_reference.html

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions (each row divided by its
                  sum; rows whose sum is 0 are shown as 0)

    save_path:    file the figure is written to before it is shown;
                  pass None or '' to skip saving

    Returns
    -------
    None. The figure is saved (optionally) and displayed with plt.show().

    Usage
    -----
    plot_confusion_matrix(cm           = cm,
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of classes names
                          title        = best_estimator_name) # title of graph
    """
    import matplotlib.pyplot as plt
    import numpy as np

    cm = np.asarray(cm)

    # Overall accuracy is the diagonal mass over the total count; an all-zero
    # matrix has no defined accuracy, so report NaN instead of dividing by 0.
    total = cm.sum()
    accuracy = np.trace(cm) / total if total else float('nan')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    # Compute the matrix that is actually displayed *before* drawing it, so
    # the cell colours, the colour bar and the printed numbers all agree.
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype('float')
        # Rows with no samples stay at 0 instead of becoming NaN.
        shown = np.divide(cm, row_sums,
                          out=np.zeros(cm.shape, dtype='float'),
                          where=row_sums != 0)
    else:
        shown = cm

    plt.figure(figsize=(8, 6))
    plt.imshow(shown, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # Cells darker than this threshold get white text, the others black.
    peak = shown.max() if shown.size else 0
    thresh = peak / 1.5 if normalize else peak / 2
    cell_format = "{:0.2f}" if normalize else "{:,}"

    for i, j in np.ndindex(*shown.shape):
        plt.text(j, i, cell_format.format(shown[i, j]),
                 horizontalalignment="center",
                 color="white" if shown[i, j] > thresh else "black")

    # Pin the axis limits to the full cell grid so the first and last rows
    # are drawn completely. The limits come from the matrix shape, so this
    # also works when target_names is None.
    n_rows, n_cols = shown.shape
    plt.xlim(-0.5, n_cols - 0.5)
    plt.ylim(n_rows - 0.5, -0.5)

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    if save_path:
        plt.savefig(save_path)
    plt.show()
调用函数
# Plot the row-normalised matrix with the default colour map.
plot_confusion_matrix(cm=cm, target_names=target_names, title='Confusion matrix', cmap=None, normalize=True)
结果
完整的代码详见Github。
如果帮助到你的话,请给我的Github项目一颗Star。