多分类问题中的混淆矩阵
#十分类问题
import matplotlib.pyplot as plt
import numpy as np
数据
from sklearn import datasets

# Handwritten-digit dataset (ten classes): X holds the feature rows,
# y the matching digit labels.
digits = datasets.load_digits()
X, y = digits.data, digits.target
分割数据集
from sklearn.model_selection import train_test_split

# Hold out 80% of the samples for testing, so the model trains on only 20%.
# No random_state is given, so the split -- and every number printed
# below -- differs from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.8,
)
使用逻辑回归解决多分类问题(旧版 scikit-learn 默认采用 OvR 方式;0.22 及以后的版本默认采用 multinomial 方式)
# Fit a logistic-regression classifier on all ten digit classes.
# NOTE: how LogisticRegression handles more than two classes depends on the
# scikit-learn version. Releases before 0.22 defaulted to one-vs-rest (OvR);
# from 0.22 on, the default lbfgs solver fits a single multinomial (softmax)
# model instead. For an explicit OvR model, wrap the estimator:
#     OneVsRestClassifier(LogisticRegression())   # from sklearn.multiclass
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)  # mean accuracy on the test set
0.93115438108484
# Predicted class label for every test sample.
y_predict = log_reg.predict(X=X_test)
precision_score 默认求解二分类问题;如果要求解多分类问题,需要设置参数 average(例如 average='micro')。对于单标签多分类问题,micro 平均的精准率等于准确率,因此下面的结果与 score 的结果相同。
# precision_score targets binary problems by default; a multiclass target
# needs an `average` strategy. 'micro' pools the true/false-positive counts
# over all classes before computing precision.
from sklearn.metrics import precision_score
precision_score(y_true=y_test, y_pred=y_predict, average='micro')
# recall_score and f1_score accept the same `average` parameter.
0.93115438108484
混淆矩阵
from sklearn.metrics import confusion_matrix

# Rows are true labels, columns are predicted labels.
confusion_matrix(y_true=y_test, y_pred=y_predict)
array([[139, 0, 0, 0, 1, 0, 2, 0, 0, 0],
[ 0, 138, 0, 3, 0, 0, 3, 1, 2, 7],
[ 0, 2, 134, 6, 0, 0, 1, 0, 5, 0],
[ 0, 2, 2, 128, 0, 2, 0, 4, 1, 2],
[ 0, 1, 0, 0, 130, 0, 0, 8, 2, 1],
[ 0, 2, 0, 0, 1, 138, 1, 1, 1, 4],
[ 0, 1, 0, 0, 0, 0, 140, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 148, 2, 0],
[ 0, 5, 1, 0, 0, 3, 1, 0, 117, 3],
[ 0, 3, 0, 1, 0, 3, 0, 7, 1, 127]], dtype=int64)
绘制一下混淆矩阵
# Show the confusion matrix as a grayscale image
# (brighter cell = larger count).
cfm = confusion_matrix(y_true=y_test, y_pred=y_predict)
plt.matshow(cfm, cmap=plt.cm.gray)
继续处理:我们关注的是误判的部分。
# Turn the raw counts into per-true-class error rates.
row_sums = np.sum(cfm, axis=1)  # number of test samples in each true class
# Divide each ROW by its own total. row_sums has shape (n_classes,), so it
# must become a column vector first: `cfm / row_sums` alone broadcasts along
# the last axis and divides column j by the total of row j, mixing up classes.
erro_matrix = cfm / row_sums[:, np.newaxis]
# Zero the diagonal (correct predictions) so only misclassifications remain.
np.fill_diagonal(erro_matrix, 0)
# NOTE(review): the erro_matrix values recorded after this cell came from the
# old column-wise division; re-run the cell to refresh them.
erro_matrix
array([[0. , 0. , 0. , 0. , 0.00704225,
0. , 0.0141844 , 0. , 0. , 0. ],
[0. , 0. , 0. , 0.0212766 , 0. ,
0. , 0.0212766 , 0.00666667, 0.01538462, 0.04929577],
[0. , 0.01298701, 0. , 0.04255319, 0. ,
0. , 0.0070922 , 0. , 0.03846154, 0. ],
[0. , 0.01298701, 0.01351351, 0. , 0. ,
0.01351351, 0. , 0.02666667, 0.00769231, 0.01408451],
[0. , 0.00649351, 0. , 0. , 0. ,
0. , 0. , 0.05333333, 0.01538462, 0.00704225],
[0. , 0.01298701, 0. , 0. , 0.00704225,
0. , 0.0070922 , 0.00666667, 0.00769231, 0.02816901],
[0. , 0.00649351, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.01538462, 0. ],
[0. , 0.03246753, 0.00675676, 0. , 0. ,
0.02027027, 0.0070922 , 0. , 0. , 0.02112676],
[0. , 0.01948052, 0. , 0.0070922 , 0. ,
0.02027027, 0. , 0.04666667, 0.00769231, 0. ]])
越亮表示犯错越多。
例如,把真值 1 预测成了 9……
# Brighter cells mark more frequent mistakes; in the recorded run, for
# example, true 1s were predicted as 9.
plt.matshow(erro_matrix, cmap=plt.cm.gray)
# With the error pattern visible, the classifier can now be tuned further.