语义分割任务中使用混淆矩阵计算IOU的原理及代码解读

本文介绍了IOU(IntersectionoverUnion)在语义分割中的重要性,通过混淆矩阵实例说明其计算方法。作者详细解读了如何基于混淆矩阵计算类别间的交并比,包括处理预测结果和Mask,以及构建和更新混淆矩阵的代码实现。
摘要由CSDN通过智能技术生成

IOU:Intersection over Union,就是交集比上并集,简称为交并比。

IOU是衡量语义分割模型性能的一个重要指标,IOU越大说明预测的结果和Mask上对应类别的重叠度越大。本文主要记录了个人对于混淆矩阵以及对应代码的理解。

首先,举个简单的例子说明下什么是混淆矩阵:

假如一张图片上有八种物体,那么在语义分割中这八种物体的像素各自有一个对应的类别索引,我们用0到7表示,那么其对应的混淆矩阵是一个8行8列的矩阵,矩阵的行索引0~7对应着真实的类别索引,矩阵的列索引对应着预测结果的类别索引。在混淆矩阵中对角线上的位置表示属于类别i且被预测为类别i的像素数,即类别i中预测正确的像素数。混淆矩阵的特点是每一行的和是该类别总的像素数,每一列的和是预测为该类别的总像素数。此时就可以根据混淆矩阵的这一特点来计算IOU了,某个类别的IOU=该类别对角线上的值/(该类别所在行的和+该类别所在列的和-该类别对角线上的值),得到就是该类别的交并比。那么以此类推可以算出每个类别的IOU,最后在求平均的话就是平均交并比mIOU。

可以用以下代码建立一个全零的混淆矩阵:

mat = torch.zeros((8, 8), dtype=torch.int64)  # num_classes=8
>>>mat
>>>tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]])

假设现在有一张图像,其经过模型处理输出的预测结果是(3,4,4),该图像对应的Mask的形状是(1,4,4),在计算IOU之前需要首先对输出结果和Mask进行处理,将Mask沿着0维展平,将预测结果在通道维上逐像素取最大值,然后沿0维展平,也就是说预测结果和Mask都要展成一维的向量进行处理。

mask.flatten()
output.argmax(1).flatten()

接下来是根据预测结果和Mask得到对应的混淆矩阵的代码,也是我琢磨了好一会的地方

n=8  # num_classes
k = (a >= 0) & (a < n)  # a is flatten mask ,b is flatten output
inds = n * a[k].to(torch.int64) + b[k]
mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

为了搞明白这几行代码的原理,令a = torch.tensor([  3, 255, 255,   4,   4,   5,   4,  -1,   3,   3,   7,   7, 255,   2,  2,   3]),b=tensor([  3, 255, 255,   3,   3,   5,   3, 255,   3,   3,   7,   7, 255,   2, 2,   3])。

k = (a>=0)&(a<n)这行代码的意思是将Mask中类别ID在0到n-1之间的像素给标记为True,得到的k是tensor([ True, False, False,  True,  True,  True,  True, False,  True,  True, True,  True, False,  True,  True,  True])。

a[k]的意思是将Mask中有效的类别ID检索出来,a[k]=tensor([3, 4, 4, 5, 4, 3, 3, 7, 7, 2, 2, 3])

b[k]的意思是检索出预测结果中和Mask相同位置的预测值,b和a中相同位置的值如果相同的话说明这个位置预测正确了,b[k]=tensor([3, 3, 3, 5, 3, 3, 3, 7, 7, 2, 2, 3]),红色位置的是预测错误的像素。

inds=n * a[k].to(torch.int64) + b[k]的作用要和torch.bincount(inds, minlength=n**2)一起看,这两行代码中bincount()的作用是统计inds中[0-max_value]内每个元素出现的次数,minlength表示输出的最小长度。

n*a[k]=tensor([24, 32, 32, 40, 32, 24, 24, 56, 56, 16, 16, 24])

inds=tensor([27, 35, 35, 45, 35, 27, 27, 63, 63, 18, 18, 27])

torch.bincount(inds, minlength=8**2)输出是长度为64的一维张量,其中的每一个元素的索引代表了inds中的一个值,该元素的大小代表了其索引值在inds中出现的次数。reshap操作将bincount的结果转换成(8,8)的张量,这就是得到的混淆矩阵。

torch.bincount(inds, minlength=8**2).reshape(8,8)

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 2, 0, 0, 0, 0, 0],
        [0, 0, 0, 4, 0, 0, 0, 0],
        [0, 0, 0, 3, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 2]])

此时回头看inds=n * a[k].to(torch.int64) + b[k],以a[1]=4,b[1]=3为例,8*4+3=35。35这个值在混淆矩阵中的位置是第5行的第4个元素(对应的预测类别3),在混淆矩阵中第5行代表的是Mask中类别为4的像素,第5列代表的是预测结果中类别为4的像素,也就是说类别为4的像素被预测成了类别3。这样就明白了为什么a[k]乘以n作用是为了保证预测结果出现在混淆矩阵中正确的位置,就可以在混淆矩阵中统计出真是类别a[k]被统计成类别a[k]的个数。

类别4的IOU:0/(3+0-0)=0,同理类别2的IOU:2/(2+2-2)=1。

mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
之所以要把混淆矩阵累加到mat上,是因为我们的数据集通常有很多图片,所以要将每个图片的混淆矩阵累加起来,这样最后得到的mat就是整个数据集上所有图片的预测结果,从而可以计算整个数据集上每个类别的IOU和平均的mIOU。

混淆矩阵的完整代码如下:

class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        # a是标签,b是预测的结果
        # a和b都是展成一维张量后进来的
        n = self.num_classes
        if self.mat is None:
            # 创建混淆矩阵
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
           
            # >=0表示类别索引从0开始
            k = (a >= 0) & (a < n)
            inds = n * a[k].to(torch.int64) + b[k]
            # 若inds中的最大值是15,那么torch.bincount(inds, minlength=n**2)统计的是inds中0到15每个元素出现的次数,minlength表示输出的最小长度
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        # 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
        acc_global = torch.diag(h).sum() / h.sum()
        # 计算每个类别的准确率
        acc = torch.diag(h) / h.sum(1)
        # 计算每个类别预测与真实目标的iou
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.mat)

    def __str__(self):

        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'  # 每个类别分类正确的像素除以这个类别的像素总数
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
                acc_global.item() * 100,
                ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
                ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
                iu.mean().item() * 100)

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值