关于多分类问题的输出与标签对应问题
import torch
def accuracy(y_hat, y):
return (y_hat.argmax(dim=1) == y).float().mean().item()
x = torch.tensor([[0,1],[1,0],[0,1]])
label = torch.tensor([1,0,1])
#label = torch.tensor([[1],[0],[0]]) #用这个会发生广播
print(x.argmax(dim=1))
print(x.argmax(dim=1
原创
2020-10-17 19:40:24 ·
621 阅读 ·
0 评论