#pytorch 向量转化为one-hot编码
import torch
#原始向量
index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
#结果
tensor([[0., 1., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 1.]])
pytorch 向量转化为one-hot编码
最新推荐文章于 2022-04-22 20:25:11 发布