Two ways to do one-hot encoding in PyTorch
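The one_hot function in torch.nn.functional
F.one_hot takes an int64 (Long) tensor of class indices and returns an int64 tensor with a new last dimension of size num_classes: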
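import torch
import torch.nn.functional as F
target = torch.tensor([0, 1, 2, 3, 4])     # class indices, dtype int64
target = F.one_hot(target, num_classes=5)  # int64 result, shape (5, 5)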
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1]])
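A few more details of F.one_hot can be shown with a short sketch; the label values below are made up for illustration. When num_classes is left at its default of -1, the class count is inferred as the largest label plus one. The result is always int64, so it typically needs a cast before being combined with float tensors. For N-D label tensors (for example a segmentation map), the new one-hot dimension is appended last.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])  # illustrative class indices

# num_classes defaults to -1: the count is inferred as max(labels) + 1 = 3
inferred = F.one_hot(labels)
print(inferred.shape)  # torch.Size([3, 3])

# one_hot always returns int64; cast when a float tensor is required
onehot_f = F.one_hot(labels, num_classes=4).float()
print(onehot_f.dtype)  # torch.float32

# N-D input: a last dimension of size num_classes is appended
seg = torch.tensor([[0, 1], [2, 0]])  # e.g. a 2x2 label map
seg_onehot = F.one_hot(seg, num_classes=3)
print(seg_onehot.shape)  # torch.Size([2, 2, 3])
```

For a channels-first layout such as (N, C, H, W), the appended dimension has to be moved explicitly, for instance with permute.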
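Using the scatter_ function
Alternatively, create a zero tensor of the desired dtype (float32 here) and use scatter_ to write a 1 in each row at the column given by that row's label: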
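import torch
target = torch.tensor([0, 2, 3, 4, 1])    # class indices
index = target.view(-1, 1).long()         # shape (N, 1); scatter_ needs an int64 index
target = torch.zeros(target.size(0), 5, dtype=torch.float32)  # (N, num_classes) zeros
target.scatter_(1, index, 1)              # in place: row i gets a 1 at column index[i]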
tensor([[1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.]])
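As a quick check, the two methods produce the same matrix once F.one_hot's int64 output is cast to float32. The sketch below reuses the labels from the scatter_ example, but uses the out-of-place Tensor.scatter and unsqueeze(1) in place of view(-1, 1) (equivalent choices, not the ones used above). It also shows that argmax along the class dimension recovers the original labels.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 3, 4, 1])
num_classes = 5

# Method 1: F.one_hot gives int64, so cast to float32 before comparing
via_one_hot = F.one_hot(labels, num_classes=num_classes).to(torch.float32)

# Method 2: scatter 1.0 into a float32 zero tensor (out-of-place variant)
via_scatter = torch.zeros(labels.size(0), num_classes, dtype=torch.float32)
via_scatter = via_scatter.scatter(1, labels.unsqueeze(1), 1.0)

print(torch.equal(via_one_hot, via_scatter))  # True

# argmax over the class dimension turns one-hot rows back into labels
recovered = via_scatter.argmax(dim=1)
print(torch.equal(recovered, labels))  # True
```

The scatter-based version is more verbose, but it produces the target dtype directly and requires num_classes to be stated explicitly.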