基于sklearn的常用分类任务指标Python实现
一、摘要
分类任务常用指标包含混淆矩阵、每类分类精度、平均分类精度、总体分类精度、f1-score等。Python的sklearn.metrics 模块覆盖了分类任务中大部分常用的验证指标,本文选择其中几种评价指标展示代码片段,供读者使用。基于tensorflow-1.0与mnist数据集做demo展示并列举实验结果。文末附有sklearn.metrics模块的相关资料链接,方便高端玩家深入探索。
二、本文包含的评价指标
混淆矩阵(Confusion Matrix,CM)每类别分类精度每类别召回率平均分类精度(Average Accuracy,AA)总体分类精度(Overall Accuracy,OA)
三、功能代码片段展示
代码在tensorflow-1.0、Python3.5环境下通过测试,tf1.0版本API改动较大,1.0以下版本tensorflow可能不能通过测试,精力有限,其他环境尚未做测试。
1 from sklearn importmetrics2 importnumpy as np3 #####
4 #Do classification task,
5 #then get the ground truth and the predict label named y_true and y_pred
6 classify_report =metrics.classification_report(y_true, y_pred)7 confusion_matrix =metrics.confusion_matrix(y_true, y_pred)8 overall_accuracy =metrics.accuracy_score(y_true, y_pred)9 acc_for_each_class = metrics.precision_score(y_true, y_pred, average=None)10 average_accuracy =np.mean(acc_for_each_class)11 score =metrics.accuracy_score(y_true, y_pred)12 print('classify_report : \n', classify_report)13 print('confusion_matrix : \n', confusion_matrix)14 print('acc_for_each_class : \n', acc_for_each_class)15 print('average_accuracy: {0:f}'.format(average_accuracy))16 print('overall_accuracy: {0:f}'.format(overall_accuracy))17 print('score: {0:f}'.format(score))
四、实验结果展示
本文基于tensorflow-1.0框架与mnist数据集,使用线性分类器与卷积神经网络分类并使用上文提到的代码片段展示分类性能。
分类性能结果直观,排列清晰,便于二次使用。
1. 线性分类器分类报告:
2. 线性分类器混淆矩阵与其他分类指标展示:
3. 卷积神经网络每层参数显示:
4. 卷积神经网络分类报告:
5. 卷积神经网络混淆矩阵与其他分类指标展示:
五、代码示例
使用类似如下的代码片段可以直观查看tensor相关内容
1 print(some_tensor.op.name, ' ', some_tensor.get_shape().as_list())