top-k概念
该算法主要用于分类任务中,网络output为分类标签的one-hot编码,target为正确标签(一个值)
例如
对一个图像进行分类,类别数为5,该图像标签为3,即target为3,网络output为[0.1, 0.3, 0.25, 0.2, 0.15]
output中最大值下标为2,第二大值下标为3
若只计算top-1,则分类错误;若计算top-2,则分类正确,因为预测值最大的前两个包含正确标签
top-k准确率就是用来计算预测结果中概率最大的前k个结果包含正确标签的占比,平时使用最多的acc即为top-1准确率。
实现代码
逐行解释
— 的 shape 为(batch_size * maxk),pred 的 shape 为(batch_size * maxk)
topk=(1,5) # 预测top-1和top-5
maxk = max(topk) # 按照最大topk值构建张量
batch_size = target.size(0) # 取 batch size
# topk返回两个张量:values和indices,分别对应前k大值的数值和索