import os

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc
from prettytable import PrettyTable
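
# Assumption: check_accuracy below references a global `device` that was defined
# outside this snippet in the original project; a typical definition is used here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")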
class ConfusionMatrix(object):
    """
    Note: if the plotted figure is clipped, it is a matplotlib version issue.
    This example renders correctly with matplotlib 3.2.1 (Windows and Ubuntu).
    The prettytable package must be installed separately.
    """

    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        # Rows index the predicted class, columns index the true class.
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1
    def summary(self):
        # Overall accuracy: sum of the diagonal divided by the total count.
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("the model accuracy is ", acc)

        # Per-class precision, recall, specificity and F1 score.
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity", "f1_score"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN
            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
            # F1 is the harmonic mean of precision and recall: 2*P*R / (P + R).
            f1_score = round(2 * Precision * Recall / (Precision + Recall), 3) \
                if Precision + Recall != 0 else 0.
            table.add_row([self.labels[i], Precision, Recall, Specificity, f1_score])
        print(table)
    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        plt.yticks(range(self.num_classes), self.labels)
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')
        # Annotate each cell; switch to white text on dark cells for readability.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()
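
# Hedged usage sketch (not from the original code): shows how the ConfusionMatrix
# class above could be exercised on its own. The toy predictions/labels below are
# made up for illustration only.
def _confusion_matrix_demo():
    cm = ConfusionMatrix(num_classes=2, labels=[0, 1])
    example_preds = np.array([0, 1, 1, 0, 1])   # hypothetical model predictions
    example_labels = np.array([0, 1, 0, 0, 1])  # hypothetical ground truth
    cm.update(example_preds, example_labels)
    cm.summary()  # prints accuracy and the per-class metric table
    cm.plot()     # draws the matrix with matplotlib
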
def save_checkpoint(state, filename="/home/tlz/GCCS_0916/checkpoint"):
    print("=> Saving checkpoint")
    torch.save(state, filename)
def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
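
# Hedged usage sketch (not from the original code): how the two checkpoint helpers
# above might be called around a training loop. `model` and `optimizer` stand in
# for whatever the original project instantiates.
def _checkpoint_demo(model, optimizer, path="checkpoint.pth.tar"):
    state = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
    save_checkpoint(state, filename=path)
    load_checkpoint(torch.load(path), model, optimizer)
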
def check_accuracy(loader, model, test, epoch, path=None):
    confusion = ConfusionMatrix(num_classes=2, labels=[0, 1])
    num_correct = 0
    num_samples = 0
    model.eval()
    # Accumulate ground-truth labels and raw scores for the ROC curve.
    y_sum = np.array([])
    scores_sum = np.array([])
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.to(device=device)
            y = y.to(device=device)
            scores = model(x)
            y_sum = np.append(y_sum, y.to(torch.device('cpu')).numpy())
            scores_sum = np.append(scores_sum, scores.to(torch.device('cpu')).numpy())
            # Binarise the scores at 0.5 to obtain class predictions.
            predictions = torch.tensor([0 if i < 0.5 else 1 for i in scores]).to(device=device)
            confusion.update(predictions.to("cpu").numpy(), y.to("cpu").numpy())
    if test:
        # Plot and save the ROC curve for this epoch.
        fpr, tpr, thresholds = roc_curve(np.array(y_sum), np.array(scores_sum), pos_label=1)
        roc_auc = auc(fpr, tpr)
        plt.figure()
        lw = 2
        plt.plot(fpr, tpr, color='y',
                 lw=lw, label='Original (AUC = %0.4f)' % roc_auc)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")
        jpg_name = os.path.join(path, str(epoch) + '.jpg')
        plt.savefig(jpg_name)
        plt.close()
    confusion.summary()
    model.train()
    return None
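
# Hedged usage sketch (not from the original code): a minimal way to call
# check_accuracy after an epoch. `val_loader`, `model`, `epoch` and the output
# directory are hypothetical placeholders for the original project's objects.
def _evaluation_demo(val_loader, model, epoch, out_dir="./roc_curves"):
    os.makedirs(out_dir, exist_ok=True)
    # test=True also computes the ROC curve and saves it as <epoch>.jpg in out_dir.
    check_accuracy(val_loader, model, test=True, epoch=epoch, path=out_dir)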