import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Plot a confusion matrix stored in a CSV file as an annotated heatmap.
# Rows are drawn on the y-axis ("True Labels") and columns on the x-axis
# ("Predicted Labels"); every non-zero cell is annotated with its value.

# Set font size and family
mpl.rcParams['font.size'] = 12
mpl.rcParams['font.family'] = 'Times New Roman'

# Load confusion matrix from CSV file
confusion_matrix = np.loadtxt('../confusion_matrix.csv', delimiter=',')

# Define class labels (one per matrix row/column, in matrix order)
labels = [
    'Airport', 'BareLand', 'BaseballField', 'Beach', 'Bridge', 'Center',
    'Church', 'Commercial', 'DenseResidential', 'Desert', 'Farmland',
    'Forest', 'Industrial', 'Meadow', 'MediumResidential', 'Mountain',
    'Park', 'Parking', 'Playground', 'Pond', 'Port', 'RailwayStation',
    'Resort', 'River', 'School', 'SparseResidential', 'Square', 'Stadium',
    'StorageTanks', 'Viaduct',
]
n_classes = len(labels)

# Fail early with a clear message instead of drawing a mislabeled plot or
# hitting an IndexError halfway through the annotation loop.
if confusion_matrix.shape != (n_classes, n_classes):
    raise ValueError(
        'confusion matrix has shape {}, expected ({}, {}) to match the '
        'labels'.format(confusion_matrix.shape, n_classes, n_classes)
    )

# Set figure size
fig, ax = plt.subplots(figsize=(14, 12))

# Plot heatmap with origin set to 'lower'. The extent makes cell (i, j) span
# [j, j + 1] x [i, i + 1], so cell centres sit at (j + 0.5, i + 0.5).
im = ax.imshow(
    confusion_matrix,
    cmap='Blues',
    extent=[0, n_classes, 0, n_classes],
    origin='lower',
)

# Set ticks at the middle of each cell
tick_positions = np.arange(n_classes) + 0.5
ax.set_xticks(tick_positions)
ax.set_xticklabels(labels, rotation=90, fontsize=12)
ax.set_yticks(tick_positions)
ax.set_yticklabels(labels, fontsize=12)

# Hide tick marks
ax.tick_params(axis='both', which='both', length=0)

# Add color bar
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8)

# Set axis labels and title
ax.set_xlabel('Predicted Labels', fontsize=16)
ax.set_ylabel('True Labels', fontsize=16)
ax.set_title('Confusion Matrix', fontsize=17)

# Add text labels to each non-zero cell. 'Blues' gets darker as values grow,
# so cells above the midpoint of the value range use white text to stay
# readable; the rest keep black text.
contrast_threshold = (confusion_matrix.min() + confusion_matrix.max()) / 2.0
for (i, j), value in np.ndenumerate(confusion_matrix):
    if value != 0:
        ax.text(
            j + 0.5,
            i + 0.5,
            '{:.1f}'.format(value),
            ha='center',
            va='center',
            color='white' if value > contrast_threshold else 'black',
            fontsize=10,
        )

# Invert y-axis so the first row of the matrix is drawn at the top
ax.invert_yaxis()

# Show plot
plt.show()
# Plotting a confusion matrix for a classification task
# (Web-page residue: "latest recommended article" published 2024-07-22 20:35:33)