基于混淆矩阵的多类IoU计算方式理解

1. IoU计算方式说明的相关文章:

[1] https://blog.csdn.net/qq_21368481/article/details/80424754

[2] ​​​​​​https://www.daimajiaoliu.com/daima/479a7286b90041c

2. 自己实现混淆矩阵

import numpy as np

gt = np.array([[0,1,0],[2,1,0],[2,2,1]])
pd = np.array([[0,2,0],[2,1,0],[1,2,1]])
class_num = 3
mm = np.zeros_like(gt)

for i in range(class_num):
    for j in range(class_num):
        mm[i,j] = np.sum((gt==i) * (pd==j))

print(mm)

输出结果:

[[3 0 0]
 [0 2 1]
 [0 1 2]]

Process finished with exit code 0

与参考文献1的结果一致。

3. 基于np.bincount 混淆矩阵计算方式理解:

        np.bincount(n * gt[mask].flatten() + pd[mask].flatten())。

首先举一个例子,二维矩阵的索引与二维矩阵经过扁平化后的一维矩阵的索引的对应关系:对于一个维度为N*N的矩阵,维度调整成一维。则,二维矩阵中第i行第j列元素与一维的向量第Index的对应关系为 Index = i * N  + j。

图像的真实标注gt,与模型的预测结果中存放0,..,N-1.(N表示类别的数量。),例如:

ground truth is 

 [[0 1 0]
 [2 1 0]
 [2 2 1]]

prediction result is
 
 [[0 2 0]
 [2 1 0]
 [1 2 1]]

我们的目的是构造一个混淆矩阵(N*N,N是类的数量),来统计真实类别与预测类别之间的TP,TN,FP,FN数量。混淆矩阵的对角线记录TP的数量,混淆矩阵的每一行记录真实类别i,被预测成类别i,与其他类的情况。gt、pd中存放的标签数值正好对应 混淆矩阵 的行索引,列索引。混淆矩阵的每一列j记录其他真实类被预测成类j的情况。

例如:N=3。一维的混淆矩阵的索引为: 0,1,2,3,4,5,6,7,8。

gt:

010
210
221

pd:

020
210
121

通过这种方式 n * gt[mask].flatten() + pd[mask].flatten() 计算出一维混淆矩阵的索引。然后再利用np.bincount统计出[0,8]范围内索引出现的频率,从而计算出类i 被预测成类 j的数量。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值