top-k准确率计算

top-k概念

该算法主要用于分类任务中,网络output为分类标签的one-hot编码,target为正确标签(一个值)
例如
对一个图像进行分类,类别数为5,该图像标签为3,即target为3,网络output为[0.1, 0.3, 0.25, 0.2, 0.15]
output中最大值下标为2,第二大值下标为3
若只计算top-1,则分类错误;若计算top-2,则分类正确,因为预测值最大的前两个包含正确标签

top-k准确率就是用来计算预测结果中概率最大的前k个结果包含正确标签的占比,平时使用最多的acc即为top-1准确率。

实现代码

逐行解释

— 的 shape 为(batch_size * maxk),pred 的 shape 为(batch_size * maxk)

topk=(1,5)  # 预测top-1和top-5
maxk = max(topk) # 按照最大topk值构建张量
batch_size = target.size(0)  # 取 batch size
# topk返回两个张量:values和indices,分别对应前k大值的数值和索引
_, pred = output.topk(maxk, 1, True, True)

correct 的 shape 为( maxk * batch_size),内容为bool值

pred = pred.t() # 对 pred 进行转置 maxk*batch_size=3*4
# expand_as将target张量扩展为pred的大小
# eq输出元素相等的布尔值
correct = pred.eq(target.view(1, -1).expand_as(pred))

返回一个装有k个值的数组,每个值对应top-k的值

[correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]

整体代码

将整体代码封装成一个函数,可以直接调用

def accuracy(output, target, topk=(1,5)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值