origin label to one hot
- 构造与原始标签维度相同的矩阵
ones = torch.sparse.torch.eye(2)
假设为二分类问题,输出为:
tensor([[1., 0.],
[0., 1.]])
- 根据原始标签填充上述矩阵
例如原始标签为:tar=torch.Tensor([1])
通过ones.index_select(0, tar.long())
处理后结果为:
tensor([[1., 0.]])
注意: tar
需要转化为一维的形式,数据类型为long