Tensor.
scatter_
(dim, index, src, reduce=None)
确实是比较难以理解的一个函数。
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
感觉像是什么呢 ? 就是src 中的数,按照index中指示的位置,放在tensor中去。dim 会表示index 来表示哪个维度的坐标。
比如 代码: c.scatter(dim=0, index = index, src =src)
下面的代码 一看就懂scatter在干嘛了 。 (这里是dim为1和2的时候 )
index = torch.tensor([[0, 1, 2, 0]])
c = torch.zeros(3, 5, dtype=index.dtype)
# print(c.scatter_(0, b, src))
dim = 0
m = len(index)
n = len(index[0])
for i in range(m):
for j in range(n):
new_index = index[i][j]
if dim == 0:
c[new_index][j] = src[i][j]
if dim == 1:
c[i][new_index] = src[i][j]
print(c)
官网的例子:
>>> src = torch.arange(1, 11).reshape((2, 5)) >>> src tensor([[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10]]) >>> index = torch.tensor([[0, 1, 2, 0]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) tensor([[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]])
下图是代码运行结果。
加一个维度
index = torch.tensor([[0, 1, 2, 0],
[1,2,0,1]])
c = torch.zeros(3, 5, dtype=index.dtype)
# print(c.scatter_(0, b, src))
dim = 0
m = len(index)
n = len(index[0])
for i in range(m):
for j in range(n):
new_index = index[i][j]
if dim == 0:
c[new_index][j] = src[i][j]
if dim == 1:
c[i][new_index] = src[i][j]
print(c)
print(c.scatter(0, index,src))
#
tensor([[1, 0, 8, 4, 0],
[6, 2, 0, 9, 0],
[0, 7, 3, 0, 0]])
tensor([[1, 0, 8, 4, 0],
[6, 2, 0, 9, 0],
[0, 7, 3, 0, 0]])
当用scatter函数时 也可以用value来代替src 这样所有的值都会被替换成value的值。 比如:
print(c.scatter(0, index,src))
print(c.scatter(0, index,value=1))
tensor([[1, 0, 8, 4, 0],
[6, 2, 0, 9, 0],
[0, 7, 3, 0, 0]])
tensor([[1, 0, 1, 1, 0],
[1, 1, 0, 1, 0],
[0, 1, 1, 0, 0]])