小样本分类生成全局混淆矩阵

小样本分类因其使用元学习的训练方法,每个episode只选取其中几个类别进行训练和测试。如下图所示

因此这不像传统的图像分类,生成的混淆矩阵直接就是全部类别的分类,所以我们要想办法将小样本分类中的每个episode聚合起来生成一个含有全部类别的混淆矩阵。代码如下:

# 使用当前 episode 的类别进行标签映射
mapped_gt = [episode_classes[label] for label in gt]
mapped_pred = [episode_classes[label] for label in pred]

# 生成当前 episode 的混淆矩阵,并调整到全局混淆矩阵的尺寸
C1 = confusion_matrix(mapped_gt, mapped_pred, labels=range(len(global_class_mapping)))

global_conf_matrix += C1

在main函数中调用上述代码存在的函数时需传参

#局部和全局的映射
global_class_mapping = {i: i for i in range(params.current_class)}
#初始化全局混淆矩阵
global_conf_matrix = np.zeros((params.current_class, params.current_class))

此时即可生成全局混淆矩阵,查看可使用print(global_conf_matrix)。结果如下图所示

如果想将其生成为热力图的形式,可将生成的混淆矩阵进行保存,在jupyter notebook中进行运行。保存文件代码如下:

with open("confusion_matrix_log.txt", "w") as f:
    np.savetxt(f, global_conf_matrix2, fmt="%d")

生成热力图的代码如下:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# 从日志文件读取混淆矩阵
def load_confusion_matrix_from_log(log_file):
    return np.loadtxt(log_file, dtype=int)

# 生成并显示热力图
def plot_confusion_matrix(conf_matrix, class_names):
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix_normalized, annot=True, fmt=".2f", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix Heatmap')
    plt.show()

if __name__ == "__main__":
    log_file = "confusion_matrix_log.txt"
    conf_matrix = load_confusion_matrix_from_log(log_file)
    
    # 假设全局类别标签为 0 到 N-1
    class_names = [str(i) for i in range(conf_matrix.shape[0])]
    
    plot_confusion_matrix(conf_matrix, class_names)

结果如下图所示:

但是小样本分类每次选取的类别不固定,所以每个类别用来预测的样本总数也不尽相同,因此可以将热力图改为概率的形式,代码如下:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# 从日志文件读取混淆矩阵
def load_confusion_matrix_from_log(log_file):
    return np.loadtxt(log_file, dtype=int)

# 生成并显示热力图
def plot_confusion_matrix(conf_matrix, class_names):
    # 将混淆矩阵归一化为概率值
    conf_matrix_normalized = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix_normalized, annot=True, fmt=".2f", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix Heatmap')
    plt.show()

if __name__ == "__main__":
    log_file = "confusion_matrix_log.txt"
    conf_matrix = load_confusion_matrix_from_log(log_file)
    
    # 假设全局类别标签为 0 到 N-1
    class_names = [str(i) for i in range(conf_matrix.shape[0])]
    
    plot_confusion_matrix(conf_matrix, class_names)

结果如下图所示:

  • 4
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值