1--top_k指标的定义
常用的 top_k 指标有 top_1、top_5 和 top_10等;
top_1准确率:前1个概率最高的类别中,包括真实类别的准确率;
top_5准确率:前5个概率最高的类别中,包括真实类别的准确率;
top_10准确率:前10个概率最高的类别中,包括真实类别的准确率;
2--代码实现
import pickle
import pandas as pd
import torch
def top_k(score, label, top_k):
# 对于每一个样本,从小到大排序其类别概率,并从小到大返回其索引值
rank = score.argsort()
# for i, l in enumerate(label) 遍历返回样本 i 对应的真实标签值 l
# l in rank[i, -top_k:] 判断 真实类别l 是否在前 k 个概率最大的类别中
# 前k个概率最大的类别包含真实类别,视为预测正确,记为True
# 前k个概率最大的类别不包含真实类别,视为预测错误,记为False
hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(label)]
# 统计预测正确的数目
result = sum(hit_top_k) * 1.0 / len(hit_top_k)
return result
if __name__ == "__main__":
# score.shape: 样本数 × 类别数
# label.shape: 样本数 × 1
top_k(score, label, top_k = 5) # 计算top_5准确率