下面是使用PyTorch计算Mean Intersection over Union(mIoU)的代码:
import torch
def compute_miou(pred, target, num_classes):
"""
计算Mean Intersection over Union(mIoU)
参数:
- pred: 预测的标签,大小为[N, H, W],N为批次大小,H为图像高度,W为图像宽度
- target: 真实标签,大小为[N, H, W]
- num_classes: 类别数
返回:
- miou: mIoU值
"""
# 将预测和目标标签转换为一维数组
pred = pred.view(-1)
target = target.view(-1)
# 创建混淆矩阵
confusion_matrix = torch.zeros(num_classes, num_classes)
# 计算预测和目标标签之间的交集和并集
for p, t in zip(pred, target):
confusion_matrix[p, t] += 1
# 计算每个类别的IoU
intersection = confusion_matrix.diag()
union = confusion_matrix.sum(0) + confusion_matrix.sum(1) - intersection
# 计算每个类别的IoU并求平均
iou = intersection / union
miou = iou.mean()
return miou
使用时,只需将预测的标签和真实标签传入compute_miou函数,以及类别数,即可计算得到mIoU值。