pytorch 分类预测接口topk API使用

import torch
import torch.nn as nn

output = torch.FloatTensor([[0.91,0,0.9,0,0],
							[1.0,0,0.98,0,0.9],
							[0,0,0,1.0,0],
							[1.0,0,1.0,0,0],
							[0.9,1.0,0.98,0.89,0.60]])
_, pred = output.topk(1,1,True,True)
# print(pred)
pred = pred.t()#转置
target = torch.FloatTensor([[0, 0, 2, 2, 1]])
# print(target.view(1, -1))
exp_t = target.view(1, -1).expand_as(pred)#b.expand_as(a)就是将b进行扩充,扩充到a的维度,需要说明的是a的低维度需要比b大,例如b的shape是31,如果a的shape是32不会出错,
correct = pred.eq(exp_t.long())#注意这里的eq函数的参数必须为LongTensor类型,这里利用了long()的方法将其转化为LongTensor类型。
correct_k = correct[:1].view(-1).float().sum(0)
top1_acc = correct_k.mul_(100.0 / 5)
print("top1_acc:",top1_acc)

exp_t = target.view(1,-1).expand_as(pred)
exp_t
tensor([[0., 0., 2., 2., 1.]])

print pred
tensor([[0, 0, 3, 0, 1]])

pred.shape
(1, 5)

pred.eq(exp_t.long()) #相同维度的数据,对应位置是否相等,返回true false
tensor([[ True, True, False, False, True]])

correct.shape 【【】】
(1, 5)

correct[0:1].shape #用 ):索引都是维度不是一维,是原数据一样的维度 【【 】】
(1, 5)

correct[0].shape 【】
(5,)

correct[:1].view(-1) #一个(-1),是不管几维数组都横铺变成一维数据
tensor([ True, True, False, False, True])

correct[:1].view(-1).float() #bool变成 float
tensor([1., 1., 0., 0., 1.])

correct[:1].view(-1).float().sum(0) #sum(0)行的方向相加 3个1 相加
tensor(3.)

correct_k.mul_(100.0/5) #累加和/ 长度(batch_size), *100, 乘除可以交换
tensor(60.)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值