关于多分类问题的输出与标签对应问题
import torchdef accuracy(y_hat, y): return (y_hat.argmax(dim=1) == y).float().mean().item()x = torch.tensor([[0,1],[1,0],[0,1]])label = torch.tensor([1,0,1])#label = torch.tensor([[1],[0],[0]]) #用这个会发生广播print(x.argmax(dim=1))print(x.argmax(dim=1
原创
2020-10-17 19:40:24 ·
619 阅读 ·
0 评论