我终于想明白这个函数的意思啦!!!
scatter的意思:服务组件架构
data.scatter_(dim,index,src)
将src中数据根据index中的索引按照dim的方向填进data。
import torch
x=torch.tensor([[7, 2, 3,4, 5],
[1.1, 2.2, 3.3, 4.4, 5.5]])
y=torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
print(x)
z=torch.zeros(3,5).scatter_(0,y,x)
print(z)
#输出
'''
tensor([[7.0000, 2.0000, 3.0000, 4.0000, 5.0000],
[1.1000, 2.2000, 3.3000, 4.4000, 5.5000]])
tensor([[7.0000, 2.2000, 3.3000, 4.0000, 5.0000],
[0.0000, 2.0000, 0.0000, 4.4000, 0.0000],
[1.1000, 0.0000, 3.0000, 0.0000, 5.5000]])
'''
y=torch.zeros(3,5).scatter_(0,y,x)
这句话的意思是:
首先:创建一个(3,5)的〇为z,然后,将x中的值补充到〇中。
根据z进行补充,如果若干个数字填充到相同的位置,则按顺序,看最后是什么,则确定是什么。
[[0,1,2,0,0],[2,0,0,1,2]]
(0,0),(0,1),(0,2),(0,3),(0,4)
(1,0),(1,1),(1,2),(1,3),(1,4)
这是x数据的位置,这时候,要x数据的位置坐标,根据y更改,由于dim=0,即更改第0维度
(0,0),(1,1),(2,2),(0,3),(0,4)
(2,0),(0,1),(0,2),(1,3),(2,4)
x数据在z中位置就是上面坐标。如果有几个数据对应的坐标相同,则一层层覆盖,显示最后一个数据。
import torch
x=torch.tensor([[7, 2, 3,4, 5],
[1.1, 2.2, 3.3, 4.4, 5.5]])
y=torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
print(x)
z=torch.zeros(3,5).scatter_(1,y,x)
print(z)
#输出
'''
tensor([[7.0000, 2.0000, 3.0000, 4.0000, 5.0000],
[1.1000, 2.2000, 3.3000, 4.4000, 5.5000]])
tensor([[5.0000, 2.0000, 3.0000, 0.0000, 0.0000],
[3.3000, 4.4000, 5.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
'''
(0,0),(0,1),(0,2),(0,3),(0,4)
(1,0),(1,1),(1,2),(1,3),(1,4)
这是x数据的位置,这时候,要x数据的位置坐标,根据y=[[0,1,2,0,0],[2,0,0,1,2]]更改,由于dim=1,即更改第1维度
(0,0),(0,1),(0,2),(0,0),(0,0)
(1,2),(1,0),(1,0),(1,1),(1,2)
三个数字的新坐标都是(0,0),则以最后一次也就是x中数据5占领此位置。