torch.Tensor.scatter 有 4 个参数:
scatter(dim, index, src, reduce=None)
先忽略 Reduce,最后再解释。先从最简单的开始。我们有一个 (2,4) 形状的张量,里面填充了 1:
粉红色的符号表示张量结构
并且我们传入相应的参数并得到输出:
注意index张量结构
现在我们增加index张量的第二个值,并比较输出:
观察数字 6 在output张量中的移动情况
好的,数字 6 由index张量内的第二个值控制。但是,如何控制呢?
以下是幕后发生的事情。
首先,我们将index形状扩展为与 src 相同的形状:
它实际上不需要扩展。但这将有助于我们理解
如果 index 中有值,则从 src 中提取相应的值。
如果没有值,则不执行任何操作。
这里有 0 和 3,因此提取 5 和 6:
这意味着index的结构必须是 src 的子结构。否则,你将收到错误: