scatter_(input, index, dim):将src中数据根据index中的索引按照dim的方向填进input。
input: 填入的值
index:填在那个元素位置的索引(二维)
dim:沿着哪个维度进行索引填充
代码:
import torch
class_num = 10
batch_size = 5
label = torch.LongTensor([0,1,4,2,3])
print(label)
one_hot = torch.zeros(batch_size, class_num).scatter_(1, label.reshape(-1,1), 1)
print(one_hot)
输出:
# tensor([0, 1, 4, 2, 3])
# tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])
#
# Process finished with exit code 0