python混淆矩阵函数_混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)...

本文介绍了机器学习中混淆矩阵的概念、在scikit-learn和TensorFlow中的API实现,以及如何通过示例展示其在评估分类模型性能中的作用。它展示了如何计算混淆矩阵并解读其在二分类任务中的元素含义,同时提到了不同库中的参数设置和示例代码。
摘要由CSDN通过智能技术生成

原理

在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能. 混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量. 这个矩阵的每一行表示真实类中的实例, 而每一列表示预测类中的实例 (Tensorflow 和 scikit-learn 采用的实现方式). 也可以是, 每一行表示预测类中的实例, 而每一列表示真实类中的实例 (Confusion matrix From Wikipedia 中的定义). 通过混淆矩阵, 可以很容易看出系统是否会弄混两个类, 这也是混淆矩阵名字的由来.

混淆矩阵是一种特殊类型的列联表(contingency table)或交叉制表(cross tabulation or crosstab). 其有两维 (真实值 "actual" 和 预测值 "predicted" ), 这两维都具有相同的类("classes")的集合. 在列联表中, 每个维度和类的组合是一个变量. 列联表以表的形式, 可视化地表示多个变量的频率分布.

使用混淆矩阵( scikit-learn 和 Tensorflow)

下面先介绍在 scikit-learn 和 tensorflow 中计算混淆矩阵的 API (Application Programming Interface) 接口函数, 然后在一个示例中, 使用这两个 API 函数.

scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口

skearn.metrics.confusion_matrix(

y_true,#array, Gound true (correct) target values

y_pred, #array, Estimated targets as returned by a classifier

labels=None, #array, List of labels to index the matrix.

sample_weight=None #array-like of shape = [n_samples], Optional sample weights

)

在 scikit-learn 中, 计算混淆矩阵用来评估分类的准确度.

按照定义, 混淆矩阵 C 中的元素 Ci,j 等于真实值为组 i , 而预测为组 j 的观测数(the number of observations). 所以对于二分类任务, 预测结果中, 正确的负例数(true negatives, TN)为 C0,0; 错误的负例数(false negatives, FN)为 C1,0; 真实的正例数为 C1,1; 错误的正例数为 C0,1.

如果 labels 为 None, scikit-learn 会把在出现在 y_true 或 y_pred 中的所有值添加到标记列表 labels 中, 并排好序.

Tensorflow 混淆矩阵函数 tf.confusion_matrix API 接口

tf.confusion_matrix(

labels,#1-D Tensor of real labels for the classification task

predictions, #1-D Tensor of predictions for a givenclassification

num_classes=None, #The possible number of labels the classification task can have

dtype=tf.int32, #Data type of the confusion matrix

name=None, #Scope name

weights=None, #An optional Tensor whose shape matches predictions

)

Tensorflow tf.confusion_matrix 中的 num_classes 参数的含义, 与 scikit-learn sklearn.metrics.confusion_matrix 中的labels 参数相近, 是与标记有关的参数, 表示类的总个数, 但没有列出具体的标记值. 在 Tensorflow 中一般是以整数作为标记, 如果标记为字符串等非整数类型, 则需先转为整数表示. 如果 num_classes 参数为 None, 则把 labels 和 predictions 中的最大值 + 1, 作为 num_classes 参数值.

tf.confusion_matrix 的 weights 参数和 sklearn.metrics.confusion_matrix 的 sample_weight 参数的含义相同, 都是对预测值进行加权, 在此基础上, 计算混淆矩阵单元的值.

使用示例

#!/usr/bin/env python#-*- coding: utf8 -*-

"""

Author: klchang

Description:

A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.

Date: 2018.9.8

"""from __future__ importprint_functionimporttensorflow as tfimportsklearn.metrics

y_true= [1, 2, 4]

y_pred= [2, 2, 4]#Build graph with tf.confusion_matrix operation

sess =tf.InteractiveSession()

op=tf.confusion_matrix(y_true, y_pred)

op2= tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3]))#Execute the graph

print ("confusion matrix in tensorflow:")print ("1. default: \n", op.eval())print ("2. customed: \n", sess.run(op2))

sess.close()#Use sklearn.metrics.confusion_matrix function

print ("\nconfusion matrix in scikit-learn:")print ("1. default: \n", sklearn.metrics.confusion_matrix(y_true, y_pred))print ("2. customed: \n", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))

参考资料

1. Confusion matrix. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Confusion_matrix

2. Contingency table. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Contingency_table

3. Tensorflow API - tf.confusion_matrix. https://www.tensorflow.com/api_docs/python/tf/confusion_matrix

4.  scikit-learn API - sklearn.metrics.confusion_matrix. http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值