习惯了转载进行Mark,现在是来还债时候了,一个一个补齐。
import torch
class ConfusionMatrix:
def __init__(self, num_classes):
"""
初始化混淆矩阵
Args:
num_classes (int): 类别数量
"""
self.num_classes = num_classes
self.matrix = torch.zeros((num_classes, num_classes), dtype=torch.int64)
def update(self, preds, targets):
"""
更新混淆矩阵
Args:
preds (torch.Tensor): 模型预测的类别标签,形状为 [N, ...]
targets (torch.Tensor): 真实类别标签,形状需与 preds 一致
"""
# 确保输入为一维张量
preds = preds.flatten()
targets = targets.flatten()
# 确保数据类型为长整型
preds = preds.to(torch.int64)
targets = targets.to(torch.int64)
# 过滤无效数据(标签超出类别范围)
mask = (targets >= 0) & (targets < self.num_classes)
targets = targets[mask]
preds = preds[mask]
# 计算线性索引
indices = targets * self.num_classes + preds
# 统计出现次数
counts = torch.bincount(
indices,
minlength=self.num_classes ** 2
).reshape(self.num_classes, self.num_classes)
# 更新矩阵
self.matrix += counts.to(self.matrix.device)
def compute(self):
"""
返回当前混淆矩阵
"""
return self.matrix
def reset(self):
"""
重置矩阵
"""
self.matrix.zero_()
# 示例用法
if __name__ == "__main__":
num_classes = 3
cm = ConfusionMatrix(num_classes)
# 模拟数据(通常从数据加载器和模型获取)
targets = torch.tensor([0, 1, 2, 0, 1, 2])
preds = torch.tensor([0, 1, 1, 0, 2, 2])
cm.update(preds, targets)
matrix = cm.compute()
print("Confusion Matrix:")
print(matrix)
# 输出解释:
# matrix[i][j] 表示真实类别 i 被预测为类别 j 的次数
输出结果:
Confusion Matrix:
tensor([[2, 0, 0],
[0, 1, 1],
[0, 1, 1]])