一、多标签混淆矩阵
cm= multilabel_confusion_matrix(y_test,y_pred_df )
这里的y_test与y_pred_df需要是二进制的多标签编码
多标签分类的混淆矩阵与多分类混淆矩阵不同,不再是二维矩阵,而变成了三维矩阵,第一维度是多个标签类别数的相加,如我的数据有两个标签,第一个标签下有六类,第二个标签有16类,那么多标签混淆矩阵的第一维度就是22,其形状是22*2*2
[[[123 0]
[ 1 20]]
[[133 0]
[ 2 9]]
[[134 0]
[ 1 9]]...........]
也就是有22个2*2的小矩阵组成。
第一个矩阵代表第一类的分类情况(我的数据类别标签在第一篇博文中)
第一行数据的第一个,123表示不是第一类并且判断时也认为不是第一类的样本数,第一行数据第二个,0表示不是第一类但判断为第一类的样本数
第二行数据第一个,1表示是第一类但判断结果不是第一类,第二个数据20,表示是第一类且判断结果也是第一类
二、多标签混淆矩阵的可视化
需要将每个小矩阵都画成热图,然后组合在一起
import seaborn as sns import matplotlib.pyplot as plt matrix_list = cm # 设置画布大小 plt.figure(figsize=(10, 10)) # 遍历小矩阵列表 labels = ["".join("c" + str(i)) for i in range(0, 22)] for i, matrix in enumerate(matrix_list): labels = ["A", "A1", "A2", "B", "B1", "B2", "B3", "B4", "C", "C1", "C2", "C3", "D", "D1", "D2", "E", "E1", "E2", "F", "F1", "F2", "F3"]#这里是我的类别,要作为热图的标题 plt.subplot(6, 4, i + 1) # 创建一个子图,共有6行4列,当前子图为第i+1个子图 sns.heatmap(matrix, annot=True,fmt="d", cmap="Blues", cbar=False,xticklabels=["N","Y"],yticklabels=["N","Y"]) # 绘制热图,不显示颜色条 plt.title(f"class{labels[i]}")#为每个热图加上对应的类别标题 # 调整子图之间的间距 plt.tight_layout() # 显示图像 plt.show()