from: https://github.com/Sun-DongYang/Pytorch/blob/master/multiLabel/multiLabel.py
多标签计算准确率的方式:按阈值,或者概率最大的前top个标签
# 计算准确率——方式1
# 设定一个阈值,当预测的概率值大于这个阈值,则认为这幅图像中含有这类标签
def calculate_acuracy_mode_one(model_pred, labels):
# 注意这里的model_pred是经过sigmoid处理的,sigmoid处理后可以视为预测是这一类的概率
# 预测结果,大于这个阈值则视为预测正确
accuracy_th = 0.5
pred_result = model_pred > accuracy_th
pred_result = pred_result.float()
pred_one_num = torch.sum(pred_result)
if pred_one_num == 0:
return 0, 0
target_one_num = torch.sum(labels)
true_predict_num = torch.sum(pred_result * labels)
# 模型预测的结果中有多少个是正确的
precision = true_predict_num / pred_one_num
# 模型预测正确的结果中,占所有真实标签的数量
recall = true_predict_num / target_one_num
return precision.item(), recall.item()
# 计算准确率——方式2
# 取预测概率最大的前top个标签,作为模型的预测结果
def calculate_acuracy_mode_two(model_pred, labels):
# 取前top个预测结果作为模型的预测结果
precision = 0
recall = 0
top = 5
# 对预测结果进行按概率值进行降序排列,取概率最大的top个结果作为模型的预测结果
pred_label_locate = torch.argsort(model_pred, descending=True)[:, 0:top]
for i in range(model_pred.shape[0]):
temp_label = torch.zeros(1, model_pred.shape[1])
temp_label[0,pred_label_locate[i]] = 1
target_one_num = torch.sum(labels[i])
true_predict_num = torch.sum(temp_label * labels[i])
# 对每一幅图像进行预测准确率的计算
precision += true_predict_num / top
# 对每一幅图像进行预测查全率的计算
recall += true_predict_num / target_one_num
return precision, recall