torch_scatter.scatter(src: Tensor, index: Tensor, dim: int = - 1,
out: Optional[Tensor] = None,
dim_size: Optional[int] = None,
reduce: str = 'sum')→ Tensor
-
src – The source tensor. ( 源张量)
-
index – The indices of elements to scatter.(要分散的元素的索引)
-
dim – The axis along which to index. (default:
-1
) (要索引的轴(默认值:-1)) -
out – The destination tensor. (目标张量)
-
dim_size – If
out
is not given, automatically create output with sizedim_size
at dimensiondim
. Ifdim_size
is not given, a minimal sized output tensor according toindex.max() + 1
is returned.(如果未给出out,则在 dim 处自动创建尺寸为 dim_size 的输出。如果没有给出 dim_size,则返回根据 index.max() + 1 的最小尺寸输出张量); -
reduce – The reduce operation (
"sum"
,"mul"
,"mean"
,"min"
or"max"
). (default:"sum"
) (reduce 操作(“ sum”、“ mul”、“ mean”、“ min”或“ max”);
直观的理解:
对于三维矩阵:
y = y.scatter(dim,index,src)
#则结果为:
y[ index[i][j][k] ] [j][k] = src[i][j][k] # if dim == 0
y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
y[i][j] [ index[i][j][k] ] = src[i][j][k] # if dim == 2
对于二维矩阵:
y = y.scatter(dim,index,src)
#则:
y [ index[i][j] ] [j] = src[i][j] #if dim==0
y[i] [ index[i][j] ] = src[i][j] #if dim==1
ps: index的维度,必须和src维度相同;
举例:
>>> src = torch.randn(3, 3)
>>> src
tensor([[-1.8801, 0.9740, 1.2865],
[ 0.3140, 1.2396, -1.3452],
[-0.8937, 0.6916, -2.0134]])
>>> y = y.scatter_(0,index,src)
>>> index = torch.tensor([[0, 1, 0],[1,0,1],[2,1,0]])
>>> index
tensor([[0, 1, 0],
[1, 0, 1],
[2, 1, 0]])
>>> y = y.scatter_(0,index,src)
>>> y
tensor([[-1.8801, 1.2396, -2.0134],
[ 0.3140, 0.6916, -1.3452],
[-0.8937, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]])