import torch
'''
场景设定:有3个样本,5个类别
分类器输出的预测结果为,shape为(3,5)
标签值是整数序列:[2,0,3],表示第一个样本属于类别2,第二个样本属于类别0,第三个样本属于类别3
现需要将标签转化为one-hot形式
'''
label = torch.Tensor([[2],[0],[3]]).long() #将label列表转化为列矩阵
oh = torch.zeros(3,5).scatter_(1,label,1) #第一个参数表示按第1维度,即按列进行scatter,第二参数表示索引,第三个参数表示填充的值
print(oh)
#tensor([[0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0.]])
将标签转化为one-hot
最新推荐文章于 2024-02-29 19:00:00 发布