A.scatter_(dim, index, B) # 基本用法, tensor A 被就地scatter到 tensor B
看完一头懵
然后知乎上找了个图torch.scatter_直观理解官网示例 - 知乎
刚开始还是一头懵,后来发现是这个样子的。dim=0表示按行放置,源tensor的第0行第0列元素(也就是0.3992)放在新tensor的第0行(因为index是0), 源tensor的第0行第1列元素(0.2908)放在新tensor的第1行(因为index是1),源tensor的第0行第3列元素(0.9044)放在新tensor的第2行(因为index是2)...以此类推。
dim 和 index
这两个参数是配套的。index和源tensor维度一致(也可以为空,就不改变目标tensor),对于n-D tensor,dim可以为0~N-1。