关于语义分割评价指标和混淆矩阵的基本知识和代码这篇博客讲得很详细,我这里只是详解一下混淆矩阵计算的代码。
# 计算混淆矩阵
def _fast_hist(label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
return hist
很佩服写出这段代码的大佬,几乎一句代码就计算出了混淆矩阵。
很多博客也有讲,但是我觉得其中的核心代码没有解释清楚,大多都是告诉我们这个函数最终可以得到混淆矩阵,但具体怎么就得到混淆矩阵了还是搞不清楚。
先看输入:label_true, label_pred, n_class
label_true,表示真实的标记,这里是一个二维数组,也可以理解为一张灰度图,每个像素点的这对应于一个类别(用数字表示0,1,2,…)
label_pred,表示预测结果,格式同label_true
n_class,类别总数
假设输入label_true为:
假设输入label_pred为:
n_class = 3
函数第一句:
mask = (label_true >= 0) & (label_true < n_class)
这一句是为了保证标记的正确性(标记的每个元素值在[0, n_class)内),标记正确得到的mask是一个全为true的数组。
函数第二句:
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
这一句是计算的核心,步骤有点多,下面拆分成几个小步来分析理解。(np.bincount函数学习链接:numpy.bincount详解)
为方便书写,记 a = label_true[mask].astype(int) ,即label_true被拉平为一维数组
记 b = label_pred[mask]
则,这里的操作即相当于np.bincount(a*n + b).reshape(n, n)
(n_class记为n, minlength=n_class ** 参数是为了保证输出向量的长度为n_class * n_class。)
记 c = a*n + b, d = np.bincount( c )
根据np.bincout的特性,c中元素的每一个值是为d中以其值为index的元素+1,也就是说c中元素的值其实是对应与d的index,即d = np.bincount(c, minlength=n**)的计算相当于:
d = np.zeros((n*n,))
for ci in c:
d[ci] +=1
再将d.reshape(n, n)
则d[i, j] 的值就为,i*n + j 的值在c中出现的次数,而i是a中的值,j是b中的值,且它们在a与b相同的位置处,恰好代表了真实类别与预测类别,即d[i, j] 代表了预测结果为类别 j,实际标签为类别 i 的所有像素点的数目。
这恰好与混淆矩阵的定义相同!