Tensor.scatter_(dim, index, src, reduce=None) → Tensor
其作用是根据index将src中的值写到self中, dim决定了维度
这里需要注意的一点是self的dtype要和src的dtype相同!!!例如:
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
这里的self的dtype要和src的dtype相同。
函数的作用以3D的tensor举例子:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
举个具体的例子:
src = torch.arange(1, 11).reshape((2, 5))
# tensor([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2, 0],
[1, 0, 1, 2]])
# tensor([[0, 1, 2, 0],
# [1, 0, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 7, 0, 4, 0],
# [6, 2, 8, 0, 0],
# [0, 0, 3, 9, 0]])
# 分析:index的i取值为0-1,j的取值从0-3都可以
# self[index[0][0]][0] = self[0][0] = src[0][0] = 1
# self[index[0][1]][1] = self[1][1] = src[0][1] = 2
# self[index[0][2]][2] = self[2][2] = src[0][2] = 3
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# self[index[1][0]][0] = self[1][0] = src[1][0] = 6
# self[index[1][1]][1] = self[0][1] = src[1][1] = 7
# self[index[1][2]][2] = self[1][2] = src[1][2] = 8
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9
这里还有个有意思的事情, 上面的情况是没有重叠的情况,假设index的上下两行中有重叠的元素,比如
index = torch.tensor([[0, 1, 2, 0],
[1, 0, 1, 0]])
注意第一行的最后一个元素与第二行的最后一个元素相同了, 都为0。(之前第二行最后一个元素为2)
这样的话上面的取值
# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9
变为了
# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[0][3] = src[1][3] = 9
可以看到self[0][3]有了2个赋值,一次是根据i=0,j=3所赋的4;另一次是根据i=1,j=3所赋的9;根据前后顺序关系,9会把4个给覆盖掉,因此最终得到的结果变为:
tensor([[1, 7, 0, 9, 0],
[6, 2, 8, 0, 0],
[0, 0, 3, 0, 0]])
scatter()与scatter_()的区别在于scatter_()是原地操作的。
举例,b = a.scatter(dim, index, src)后a的值不会发生变化
相对的, b = a.scatter_(dim, index, src)后a的值发生变化, 变得与b相等