import torch
# 假设输出和目标
output = torch.tensor([
[0.1, 0.2, 0.6, 0.1], # 第一个样本的预测分数
[0.1, 0.4, 0.2, 0.3], # 第二个样本的预测分数
[0.5, 0.1, 0.2, 0.2] # 第三个样本的预测分数
])
target = torch.tensor([2, 1, 0]) # 真实的标签
# 定义准确率计算函数
def accuracy(output, target, topk):
maxk = max(topk)
_,pred = output.topk(maxk, 1, True, True)
# pred = pred.t()
# print(pred)
print(target.view(1,-1).expand_as(pred))
# 计算准确率
accuracy(output, target, topk=(1,3))
理论学习:expand_as
最新推荐文章于 2024-07-07 12:38:54 发布