标题Pytorch_scatter中的scatter包
例子:
解释:
将src按index的指示,在维度dim上进行修改。
1、关于元素位置
index=[0,0,1]中第一个0表示将第dim=1维度中的下标为0的元素[1.0226,-0.3013]放在输出的下标为0的位置;第二个0表示将下标为1的元素[-0.1796,-0.4600]放在输出下标为0的位置;同理第三个1表示将下标为3的放在输出1的位置,具体如下图:
由于第一行和第二行都换到了输出的相同维度,因此我们对这两行执行reduce函数操作(默认为sum)
2、关于输出维度
dim=1表示修改原维度[1,3,2]为[1,max[0,0,1]+1,2],即[1,2,2]。(注:输出维度也可利用out及dim_size参数指定)
3、源注释
以下是官方文档对于scatter函数的解释,供大家学习时参考。
Parameters :
src – The source tensor.
index – The indices of elements to scatter.
dim – The axis along which to index. (default: -1)
out – The destination tensor.
dim_size – If out is not given, automatically create output with size dim_size at dimension dim. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.
reduce – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum")