主要函数:
scatter_(dim, index, src)
dim:维度,表示在第几维上操作;
index:索引,后面再解释;
src:用来填充的tensor。
例子:
label 标签转换为 one-hot
label = torch.tensor([[1], [2]]) # 类别从0开始
# label: tensor([[1], [2]])
output= torch.zeros(2, 4)
# tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
output.scatter_(1, label, 1)
# tensor([[0., 1., 0., 0.],
# [0., 0., 1., 0.]])
在特定的维度(index)转换成指定的值(src)
src = torch.tensor([[1.000], [2.000]])
# src: tensor([[1.], [2.]])
index = torch.tensor([[1], [2]])
# index: tensor([[1], [2]])
output= torch.zeros(2, 4)
# tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
output.scatter_(1, index, src)
# tensor([[0., 1., 0., 0.],
# [0., 0., 2., 0.]])
参考:pytorch:scatter_函数生成one_hot向量