torch.Tensor.scatter_()
是torch.gather()
函数的方向反向操作。两个函数可以看成一对兄弟函数。gather
用来解码one hot,scatter_
用来编码one hot。
scatter_(dim, index, src) → Tensor
dim (python:int)
– 用来寻址的坐标轴index (LongTensor)
– 索引src(Tensor)
–用来scatter的源张量,以防value未被指定。value(python:float)
– 用来scatter的源张量,以防src未被指定。
现在我们来看看具体这么用,看下面这个例子就一目了然了。
- dim =0
import torch
x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443],
[0.6913, 0.8924, 0.7530, 0.8874, 0.0557]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0