在进行处理分类问题,常常需要画混淆矩阵对数据分类情况进行分析,这里安利一个混淆矩阵的方法:
1.首先导入要用到的包:
import numpy as np
import pandas as pd
import matplotlib.pyplot as pl
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
2.定义混淆矩阵函数,进行相关参数设置:
def plot_confusion_matrix(cm,
target_names,
title='Confusion matrix',
cmap='Blues',#这个地方设置混淆矩阵的颜色主题,这个主题看着就干净~
normalize=True):
accuracy = np.trace(cm) / float(np.sum(cm))
misclass = 1 - accuracy
if cmap is None:
cmap = plt.get_cmap('Blues')
plt.figure(figsize=(9, 7))
# plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title