import sklearn.metrics
import matplotlib.pyplot as plt
import numpy as np
import torch
# Plot a 5x5 grid of scores as an annotated heatmap.
#
# The 25 values below are listed row by row (row-major order). They all lie
# in [0, 1] and are scaled to percentages before plotting.
# NOTE(review): the title suggests these are results on the Cora dataset for
# every pair of settings in `class_names` — confirm against the experiment
# that produced them.
class_names = [0.2, 0.4, 0.6, 0.8, 1.0]
confusion_matrix = torch.tensor([
    0.8225, 0.8302, 0.8374, 0.8361, 0.4787,
    0.832, 0.8324, 0.8415, 0.841, 0.4332,
    0.8352, 0.8394, 0.8374, 0.8377, 0.422,
    0.8272, 0.8307, 0.8392, 0.8404, 0.4478,
    0.4536, 0.4312, 0.4233, 0.4337, 0.6849,
])
# One row and one column per label; reshape fails loudly if the number of
# values does not match the number of labels.
confusion_matrix = confusion_matrix.reshape(len(class_names), len(class_names))
confusion_matrix = confusion_matrix * 100

# matplotlib consumes NumPy arrays; convert once and reuse for drawing and text.
scores = confusion_matrix.numpy()

fig, axes = plt.subplots(1)
# Draw through the explicit Axes instead of pyplot's implicit "current axes".
axes.imshow(scores, cmap='Blues')
axes.set_title('Cora')

tick_positions = np.arange(len(class_names))
axes.set_xticks(tick_positions)
axes.set_yticks(tick_positions)
axes.set_xticklabels(class_names, ha='right', fontsize=8, rotation=40)
axes.set_yticklabels(class_names, ha='right', fontsize=8)

# Annotate every cell with its value. imshow maps the data range onto the
# 'Blues' colormap (light -> dark), so cells in the upper half of the range are
# the darker ones; switch their text to white to keep it readable.
text_threshold = (scores.min() + scores.max()) / 2
for (row, col), value in np.ndenumerate(scores):
    text_color = 'white' if value > text_threshold else 'black'
    axes.text(col, row, f'{value:.2f}', ha='center', va='center',
              color=text_color, fontsize=12)

fig.tight_layout()
plt.show()
plt.close(fig)
# Color each cell according to its value.
# (Source page footer: "Latest recommended article published 2024-09-15 22:31:42".)