在分类问题中,损失函数如果用CrossEntropyLoss,那么需要将label从one-hot转为普通的标签,用以下函数可以实现此功能:
one_hot = torch.tensor([[0,0,1],[0,1,0],[1,0,0]])
print(one_hot)
label = torch.topk(one_hot, 1)[1].squeeze(1)
print(label)
pytorch中将标签从one-hot转为普通label标签(CrossEntropyLoss)
最新推荐文章于 2023-04-13 17:45:36 发布