功能:把source矩阵的值,根据索引“撒”到目标矩阵中。
格式:torch.scatter(dim, index, source)
dim:维度
index:对该维度的索引
source:数据来源
例子:
x = torch.rand(2, 5)
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
解释:
数据源头是x,x有10个值,现在把这10个值撒到[3, 5]的矩阵中,那么每个值都要有一个新的位置索引,这个新的索引由index指定。
首先,有10个坑位:
然后把index写进去,dim=0,表示index代表第0维;
0 | 1 | 2 | 0 | 0 |
2 | 0 | 0 | 1 | 2 |
最后,按照自然顺序补充第二维索引
0 0 | 1 1 | 2 2 | 0 3 | 0 4 |
2 0 | 0 1 | 0 2 | 1 3 | 2 4 |
以第一行22为例,表示把x中[0, 2]的数据【0.8184】,路由到目标矩阵的[2, 2]位置。