top-k概念
该算法主要用于分类任务中,网络output为分类标签的one-hot编码,target为正确标签(一个值)
例如
对一个图像进行分类,类别数为5,该图像标签为3,即target为3,网络output为[0.1, 0.3, 0.25, 0.2, 0.15]
output中最大值下标为2,第二大值下标为3
若只计算top-1,则分类错误;若计算top-2,则分类正确,因为预测值最大的前两个包含正确标签
top-k准确率就是用来计算预测结果中概率最大的前k个结果包含正确标签的占比,平时使用最多的acc即为top-1准确率。
实现代码
逐行解释
— 的 shape 为(batch_size * maxk),pred 的 shape 为(batch_size * maxk)
topk=(1,5) # 预测top-1和top-5
maxk = max(topk) # 按照最大topk值构建张量
batch_size = target.size(0) # 取 batch size
# topk返回两个张量:values和indices,分别对应前k大值的数值和索引
_, pred = output.topk(maxk, 1, True, True)
correct 的 shape 为( maxk * batch_size),内容为bool值
pred = pred.t() # 对 pred 进行转置 maxk*batch_size=3*4
# expand_as将target张量扩展为pred的大小
# eq输出元素相等的布尔值
correct = pred.eq(target.view(1, -1).expand_as(pred))
返回一个装有k个值的数组,每个值对应top-k的值
[correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
整体代码
将整体代码封装成一个函数,可以直接调用
def accuracy(output, target, topk=(1,5)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]