一、混淆矩阵的计算、绘制
二、图像分割常用指标
一、混淆矩阵
1.1 混淆矩阵介绍
之前介绍过二分类混淆矩阵:《混淆矩阵、错误率、正确率、精确度、召回率、F1值、PR曲线、ROC曲线、AUC》
现在说一下多分类混淆矩阵。其实是一样的,就是长下面这样。
只有对角线上的点是分类正确的点。
有了混淆矩阵之后,就可以求各种率了。比如正确率、错误率、召回率等等。
1.2 混淆矩阵计算
(1)原理
知道了标签与预测之后,就可以得到混淆矩阵了。比如,下面的数据。共有10个数据,类别有6类。第一行是标签。第二行是预测的结果。
可以统计得到下面的混淆矩阵:
我觉得统计的方法就好的。首先初始化一个二维的表格。然后遍历一遍数据,在表格对应的位置+1。
不过有的人给了下面一个方法,也挺巧妙的,下面介绍一下。
第一步:计算n*Lable + predict。n为类别数,这里有6类,所以n=6。
第二步:初始化一个长度为n^2的数组。统计一下n*L+P数组中元素的个数,填写到对应的位置。
这一步可以使用numpy库的bincount()方法实现。
例:
结果:
第三步:把这个矩阵reshape()成n*n,就是混淆矩阵。
有没有很神奇。。。其实挺简单的。。。留作课后作业了~~
(2) 代码实现
import numpy as np
#假设每张图片的大小是2*5。
gt_label = np.array([[0, 1, 2, 3, 1],
[1, 2, 2, 3, 4]])
pre_label = np.array([[0, 1, 2, 3, 1],
[5, 1, 2, 1, 4]])
n_class = 6 #有6类
#验证标签取值是对的
mask = (gt_label >= 0) & (gt_label < n_class)
print(mask)
#通过mask可以把为true的元素挑出来,并且拉成一行
print(gt_label[mask]) #得到的是一维数组
print(pre_label[mask])
#计算混淆矩阵
x = n_class * gt_label[mask] + pre_label[mask] #第一步:计算n*Lable + predict
print(x)
res = np.bincount(x, minlength=n_class**2).reshape(n_class, n_class) #第二三步:使用numpy库的bincount()方法统计,并reshape
print(res)
结果:
计算结果正确。
将上面的方法封装成函数,如下: