小样本分类因其使用元学习的训练方法,每个episode只选取其中几个类别进行训练和测试。如下图所示
因此这不像传统的图像分类,生成的混淆矩阵直接就是全部类别的分类,所以我们要想办法将小样本分类中的每个episode聚合起来生成一个含有全部类别的混淆矩阵。代码如下:
# 使用当前 episode 的类别进行标签映射
mapped_gt = [episode_classes[label] for label in gt]
mapped_pred = [episode_classes[label] for label in pred]
# 生成当前 episode 的混淆矩阵,并调整到全局混淆矩阵的尺寸
C1 = confusion_matrix(mapped_gt, mapped_pred, labels=range(len(global_class_mapping)))
global_conf_matrix += C1
在main函数中调用上述代码存在的函数时需传参
#局部和全局的映射
global_class_mapping = {i: i for i in range(params.current_class)}
#初始化全局混淆矩阵
global_conf_matrix = np.zeros((params.current_class, params.current_class))
此时即可生成全局混淆矩阵,查看可使用print(global_conf_matrix)。结果如下图所示
如果想将其生成为热力图的形式,可将生成的混淆矩阵进行保存,在jupyter notebook中进行运行。保存文件代码如下:
with open("confusion_matrix_log.txt", "w") as f:
np.savetxt(f, global_conf_matrix2, fmt="%d")
生成热力图的代码如下:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# 从日志文件读取混淆矩阵
def load_confusion_matrix_from_log(log_file):
return np.loadtxt(log_file, dtype=int)
# 生成并显示热力图
def plot_confusion_matrix(conf_matrix, class_names):
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix_normalized, annot=True, fmt=".2f", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix Heatmap')
plt.show()
if __name__ == "__main__":
log_file = "confusion_matrix_log.txt"
conf_matrix = load_confusion_matrix_from_log(log_file)
# 假设全局类别标签为 0 到 N-1
class_names = [str(i) for i in range(conf_matrix.shape[0])]
plot_confusion_matrix(conf_matrix, class_names)
结果如下图所示:
但是小样本分类每次选取的类别不固定,所以每个类别用来预测的样本总数也不尽相同,因此可以将热力图改为概率的形式,代码如下:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# 从日志文件读取混淆矩阵
def load_confusion_matrix_from_log(log_file):
return np.loadtxt(log_file, dtype=int)
# 生成并显示热力图
def plot_confusion_matrix(conf_matrix, class_names):
# 将混淆矩阵归一化为概率值
conf_matrix_normalized = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis]
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix_normalized, annot=True, fmt=".2f", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix Heatmap')
plt.show()
if __name__ == "__main__":
log_file = "confusion_matrix_log.txt"
conf_matrix = load_confusion_matrix_from_log(log_file)
# 假设全局类别标签为 0 到 N-1
class_names = [str(i) for i in range(conf_matrix.shape[0])]
plot_confusion_matrix(conf_matrix, class_names)
结果如下图所示: