下载包
python -m pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple
导入包
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score
模型评估
def evaluate(model, data_loader, device):
    """Evaluate a classification model on a validation/test loader.

    Runs one pass over ``data_loader``, prints a per-class
    classification report, per-class / micro / macro precision, recall
    and F1, the concatenated predictions and labels, and the mean
    cross-entropy loss, then returns the overall accuracy.

    Args:
        model: a ``torch.nn.Module`` classifier producing logits of
            shape (batch, num_classes).
        data_loader: iterable yielding ``(images, labels)`` batches;
            must expose ``.dataset`` for the sample count.
        device: the ``torch.device`` on which evaluation is performed.

    Returns:
        float: overall accuracy (correct / total samples).
    """
    model.eval()
    # Total number of validation samples.
    total_num = len(data_loader.dataset)
    # Running count of correctly predicted samples.
    sum_num = torch.zeros(1).to(device)
    # Accumulators for the full prediction / ground-truth sequences.
    all_labels = []
    all_preds = []
    cont = 0
    outPre = []
    outLabel = []
    # BUG FIX: use .to(device) instead of hard-coded .cuda() so that
    # evaluation also works on CPU-only machines (the function already
    # takes `device` as a parameter).
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    mean_loss = torch.zeros(1).to(device)
    data_loader = tqdm(data_loader, file=sys.stdout)
    # FIX: evaluation needs no gradients; without no_grad() every batch
    # builds an autograd graph (and mean_loss keeps it alive), wasting
    # memory for no benefit.
    with torch.no_grad():
        for step, data in enumerate(data_loader):
            images, labels = data
            logits = model(images.to(device))
            loss = loss_function(logits, labels.to(device))
            # Numerically stable running mean of the batch losses.
            mean_loss = (mean_loss * step + loss) / (step + 1)  # update mean losses
            pred = torch.max(logits, dim=1)[1]
            sum_num += torch.eq(pred, labels.to(device)).sum()
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(pred.cpu().numpy())
            # Keep tensor copies of predictions/labels for the final dump.
            if cont == 0:
                outPre = pred.data.cpu()
                outLabel = labels.data.cpu()
            else:
                outPre = torch.cat((outPre, pred.data.cpu()), 0)
                outLabel = torch.cat((outLabel, labels.data.cpu()), 0)
            cont += 1
    y_pred = all_preds
    y_true = all_labels
    # GENERALIZATION: derive class names from the labels actually seen
    # instead of hard-coding exactly three classes, which made
    # classification_report raise for any other class count. For the
    # 3-class case this prints the same names as before.
    class_ids = sorted(set(map(int, y_true)) | set(map(int, y_pred)))
    target_names = [f'Class {c}' for c in class_ids]
    report = classification_report(y_true, y_pred, target_names=target_names)
    print(report)
    precision = precision_score(y_true, y_pred, average=None)
    recall = recall_score(y_true, y_pred, average=None)
    f1 = f1_score(y_true, y_pred, average=None)
    print("Precision:\t\t", precision)
    print("Recall:\t\t", recall)
    print("F1 Score:\t\t", f1)
    # Micro- and macro-averaged precision, recall and F1-score.
    precision_micro = precision_score(y_true, y_pred, average='micro')
    recall_micro = recall_score(y_true, y_pred, average='micro')
    f1_micro = f1_score(y_true, y_pred, average='micro')
    precision_macro = precision_score(y_true, y_pred, average='macro')
    recall_macro = recall_score(y_true, y_pred, average='macro')
    f1_macro = f1_score(y_true, y_pred, average='macro')
    print("Micro-averaged Precision Recall \tF1 Score")
    print(f'{precision_micro} {recall_micro} \t{f1_micro}')
    print("Macro-averaged Precision\t\t\tRecall\t\t\tF1 Score")
    print(f'{precision_macro}\t\t\t{recall_macro}\t\t\t{f1_macro}')
    print("outPre:", outPre)
    print("outLabel", outLabel)
    print(f'cont = {cont},total_num = {total_num}')
    acc = sum_num.item() / total_num
    print('Loss: {:.10f} Acc: {:.4f}'.format(mean_loss.item(), acc))
    return acc