在机器学习中,我们会利用一些指标(混淆矩阵、精确率、召回率、F1值、准确率)来判断我们模型的好坏,从而改进优化模型。下面介绍如何在TensorFlow下快速计算这些指标。
1、混淆矩阵
confusion_matrix = tf.contrib.metrics.confusion_matrix(labels_pred_all, labels_all, num_classes=None, dtype=tf.int32, name=None, weights=None)
confusion_matrix = sess.run(confusion_matrix)
因为第一步所计算出来的混淆矩阵是一个Tensor,所以需要进行转换。
具体api详解:
https://haosdent.gitbooks.io/tensorflow-document/content/api_docs/python/contrib.metrics.html#confusion_matrix
值得注意的是:所计算出来的混淆矩阵,列是真实值(也就是期望值),行是预测值
2、四大指标:
有了混淆矩阵,计算四大指标就好办了。
accu = [0,0,0,0,0]
column = [0,0,0,0,0]
line = [0,0