1网上的方法
1.1非零元素位置索引
起初是想找到矩阵里每一行one-hot类型的1所在的位置id.
直接转载吧,懒得整理了:
https://blog.csdn.net/qq_50001789/article/details/120606606
1.2 如何根据标签lable构造one-hot向量呢
https://blog.csdn.net/qq_34914551/article/details/88700334
2 更快速的方法
- 直接进行argmax和用torch.nn.functional中的one_hot方法更快
a = torch.zeros(size=(4,4))
rand = torch.randint(low=0,high=4,size=(4,))
print("one-hot-index:",rand)
x = [i for i in range(4)]
a[x,rand] = 1
#解决问题1
print("find one-hot-index using argmax:",torch.argmax(a,dim=1))
from torch.nn import functional as F
#解决问题2
b = F.one_hot(rand,num_classes=a.size(0))#
print(b)```