# coding: utf-8
import os
import torch
import torchvision.utils as vutils
import numpy as np
import torchvision.models as models
from torchvision import datasets
from tensorboardX import SummaryWriter
import itertools
import matplotlib.pyplot as plt
# Event files go to ../../Result/runs, resolved relative to the current
# working directory of the process (not relative to this file).
_RUNS_DIR = os.path.join("..", "..", "Result", "runs")
writer = SummaryWriter(_RUNS_DIR)
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print a confusion matrix and draw it as an annotated heat map.

    The matrix is printed to stdout (after optional normalization) and
    rendered with one text annotation per cell. Rows are labelled
    'True label' and columns 'Predicted label'.

    Args:
        cm: 2-D confusion matrix (array-like); indexed as ``cm[row, col]``.
        classes: Sequence of tick labels, one per class, used on both axes.
        normalize: If True, divide every row by its sum so each row shows
            fractions instead of raw counts. Rows whose sum is zero are
            left as all zeros instead of producing NaN.
        title: Title drawn above the plot.
        cmap: Matplotlib colormap used for the heat map.

    Returns:
        The matplotlib ``Figure`` containing the plot.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no samples has a zero row sum; skip the division for
        # those rows so they stay 0.0 rather than becoming NaN.
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    # Use explicit Figure/Axes objects so the drawing does not depend on
    # pyplot's notion of the "current" figure.
    fig, ax = plt.subplots()
    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    # Float matrices (normalized or passed in as floats) cannot be formatted
    # with 'd', so choose the cell format from the actual dtype.
    fmt = '.2f' if np.issubdtype(cm.dtype, np.floating) else 'd'
    # Cells darker than the midpoint get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    # Lay out only after the axis labels exist so they are taken into account.
    fig.tight_layout()
    return fig
# Hard-coded 5x5 confusion matrix for the five classes below
# (plot_confusion_matrix labels rows as true labels, columns as predictions).
cnf_matrix = np.array([
    [4101, 2, 5, 24, 0],
    [50, 3930, 6, 14, 5],
    [29, 3, 3973, 4, 0],
    [45, 7, 1, 3878, 119],
    [31, 1, 8, 28, 3936],
])
class_names = ['Buildings', 'Farmland', 'Greenbelt', 'Wasteland', 'Water']

# Call add_figure to put each figure into tensorboardX for display.
# The two figures use distinct tags: logging both under one tag at the same
# global_step would make them collide in TensorBoard.
try:
    writer.add_figure(
        'confusion matrix',
        figure=plot_confusion_matrix(
            cnf_matrix,
            classes=class_names,
            normalize=False,
            title='Confusion matrix, without normalization',
        ),
        global_step=1,
    )
    writer.add_figure(
        'normalized confusion matrix',
        figure=plot_confusion_matrix(
            cnf_matrix,
            classes=class_names,
            normalize=True,
            title='Normalized confusion matrix',
        ),
        global_step=1,
    )
finally:
    # Close the writer even if plotting or logging raises.
    writer.close()