mIOU计算代码

下面是使用PyTorch计算Mean Intersection over Union(mIoU)的代码:

import torch

def compute_miou(pred, target, num_classes):
    """
    计算Mean Intersection over Union(mIoU)
    
    参数:
    - pred: 预测的标签,大小为[N, H, W],N为批次大小,H为图像高度,W为图像宽度
    - target: 真实标签,大小为[N, H, W]
    - num_classes: 类别数
    
    返回:
    - miou: mIoU值
    """
    
    # 将预测和目标标签转换为一维数组
    pred = pred.view(-1)
    target = target.view(-1)
    
    # 创建混淆矩阵
    confusion_matrix = torch.zeros(num_classes, num_classes)
    
    # 计算预测和目标标签之间的交集和并集
    for p, t in zip(pred, target):
        confusion_matrix[p, t] += 1
    
    # 计算每个类别的IoU
    intersection = confusion_matrix.diag()
    union = confusion_matrix.sum(0) + confusion_matrix.sum(1) - intersection
    
    # 计算每个类别的IoU并求平均
    iou = intersection / union
    miou = iou.mean()
    
    return miou

使用时,只需将预测的标签和真实标签传入compute_miou函数,以及类别数,即可计算得到mIoU值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值