针对celeb数据集发现数据样本及其不均衡导致结果全部预测为样本多的数据
为此记录下其他的指标学习
更详细内容请查看余霆嵩pytorch教程
前言:
先看一个经典的例子:
目的:设计一个分类器类分类患者的肿瘤是不是良性的
数据集:我们有10000个样本,其中9995个是良性的,只有5个人是恶性的。
如果我们只关注预测的错误率的,我们可以设计模型无论什么样本输入,全部输出良性。此时的准确率是99.95%,是不是很满意呢?但这种模型根本没有用。所以说对于数据的类别存在偏析的情况,不能只看错误率。
因此出现了查准率、查全率、ROC、AUC…
例子
从第一行中可知道,10 张猫的图像中,7 张预测为猫,3 张预测为狗,猫的召回率(Recall)为 7/10 = 70%,
从第二行中可知道,30 张狗的图像中,8 张预测为猫,22 张预测为狗,狗的召回率为20/30 = 66.7%,
从第一列中可知道,预测为猫的 17 张图像中,有 7 张是真正的猫,猫的精确度
(Precision)为 7 / 17 = 41.17%
从第二列中可知道,预测为狗的 23 张图像中,有 20 张是真正的狗,狗的精确度
(Precision)为 20 / 23 = 86.96%
模型的准确率(Accuracy)为 7+20 / 40 = 67.5%
FN、TN、FP、TP
**查准率Precision(也叫准确率)**表示 你预测为正的样本中(TP+FP)有多少是真正的正样本(TP)(因为还有一些你错误的预测为正样本(FP)
**查全率(Recall),又叫召回率,**缩写表示用R。查全率是针对我们原来的样本而言的,它表示的是样本中的正例有多少被预测正确。
混淆矩阵代码
- 混淆矩阵的统计
第一步:创建混淆矩阵
获取类别数,创建 N*N 的零矩阵
conf_mat = np.zeros([cls_num, cls_num])
第二步:获取真实标签和预测标签
labels 为真实标签,通常为一个 batch 的标签
predicted 为预测类别,与 labels 同长度
第三步:依据标签为混淆矩阵计数
for i in range(len(labels)):
true_i = np.array(labels[i])
pre_i = np.array(predicted[i])
conf_mat[true_i, pre_i] += 1.0