PyTorch scatter_ 从懵懂到清晰
今天看pytorch官方文档,看到一个函数scatter_,好奇了一下去看官方解释,乍一看寥寥几句解释,却把我整懵了,赶紧搜了一下,发现好几篇文章都是写的差不多,也没看懂,后来终于看到这篇文章,才算搞明白了是咋回事 原文链接。
重点来了老弟
直接贴图:
假设我们有这样一个tensor
调用scatter_方法,会得到如下结果:
(3,5) 表示输出的tensor的维度,初始值都为0;
0 表示按行这一维度填充,暂用dim表示,1表示按列填充;
[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]] 表示索引,暂用index表示;
x 就是第一个图的输入了。
执行原理是这样的:
可以将x与index一一对应起来,那么函数的意思就是:将与index中每个位置对应的x中的值填充到输出b中index中该位置值对应行中,列保持不变,感觉说不太明白omg。
举个栗子:
- index中索引(0,0)的值0对应x中(0,0)的0.0862,那意思就是将0.0862填充到输出的第0行中,列保持不变,即填充到b中的(0,0)位置;
- index中索引(0,1)的值1对应x中(0,1)的0.6349,那意思就是将0.6349填充到输出的第1行中,列保持不变,即填充到b中的(1,1)位置;
- index中索引(1,3)的值1对应x中(1,3)的0.7461,那意思就是将0.7461填充到输出的第1行中,列保持不变,即填充到b中的(1,3)位置。
这个函数执行原理就是这样了。
如果将dim的值改为1,那么就是保持行不变,根据index中的值填充到对应的列中,同样的x作为输入,dim改为1后的输出为:
可以在纸上将这个过程写一下会更清晰。
附两个图