分类相关导航:
【机器学习】分类任务以mnist为例,数据集准备及预处理
【机器学习】scikitLearn分类任务以mnist为例,训练二分类器并衡量性能指标:ROC及PR曲线
【机器学习】scikitLearn分类:鉴别二分类、多分类、多标签及多输出的分类任务
1.绘制混淆矩阵:
#1.建立模型
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
sgd_clf.fit(X_train, y_train_5)
#2.绘制混淆矩阵
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
#3.给定绘图用函数:
def plot_confusion_matrix(matrix):
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)
#matplotlib.pyplot.matshow()函数用于在新图形窗口中将数组表示为矩阵
cax = ax.matshow(matrix)
fig.colorbar(cax)
#4.调用函数进行绘图
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
图中颜色越白的区域,图片越多。对角线上是被正确分类的图片。
2.更改矩阵绘制方式,将注意力焦点集中在错误分类的值:
#求出每行的和
row_sums = conf_mx.sum(axis=1, keepdims=True)
#算出每行每格所占的比例,行代表真实值,这个比例代表真值对应的不同预测值
norm_conf_mx = conf_mx / row_sums
#将矩阵对角线,理论上最多的值全部置为0,增加其余位置色块的对比度
np.fill_diagonal(norm_conf_mx, 0)
#使用matshow进行矩阵展示
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()
如图所示,越亮的色块错误越多,如真值5被错分为真值8的情况最多,可通过对图像进行预处理,或对5及8样本数增广的形式,提高对应分类的正确率。