import torch
import torch.nn as nn
output = torch.FloatTensor([[0.91,0,0.9,0,0],
[1.0,0,0.98,0,0.9],
[0,0,0,1.0,0],
[1.0,0,1.0,0,0],
[0.9,1.0,0.98,0.89,0.60]])
_, pred = output.topk(1,1,True,True)
# print(pred)
pred = pred.t()#转置
target = torch.FloatTensor([[0, 0, 2, 2, 1]])
# print(target.view(1, -1))
exp_t = target.view(1, -1).expand_as(pred)#b.expand_as(a)就是将b进行扩充,扩充到a的维度,需要说明的是a的低维度需要比b大,例如b的shape是31,如果a的shape是32不会出错,
correct = pred.eq(exp_t.long())#注意这里的eq函数的参数必须为LongTensor类型,这里利用了long()的方法将其转化为LongTensor类型。
correct_k = correct[:1].view(-1).float().sum(0)
top1_acc = correct_k.mul_(100.0 / 5)
print("top1_acc:",top1_acc)
exp_t = target.view(1,-1).expand_as(pred)
exp_t
tensor([[0., 0., 2., 2., 1.]])print pred
tensor([[0, 0, 3, 0, 1]])pred.shape
(1, 5)pred.eq(exp_t.long()) #相同维度的数据,对应位置是否相等,返回true false
tensor([[ True, True, False, False, True]])
correct.shape 【【】】
(1, 5)correct[0:1].shape #用 ):索引都是维度不是一维,是原数据一样的维度 【【 】】
(1, 5)correct[0].shape 【】
(5,)correct[:1].view(-1) #一个(-1),是不管几维数组都横铺变成一维数据
tensor([ True, True, False, False, True])correct[:1].view(-1).float() #bool变成 float
tensor([1., 1., 0., 0., 1.])correct[:1].view(-1).float().sum(0) #sum(0)行的方向相加 3个1 相加
tensor(3.)correct_k.mul_(100.0/5) #累加和/ 长度(batch_size), *100, 乘除可以交换
tensor(60.)