对于torch.tensor.scatter()这个函数的理解。

本文详细解读了TensorFlow中的scatter_函数,通过实例演示如何根据index指示在tensor中替换或插入源数据,重点讲解了dim参数对索引选择的影响,并展示了不同维度的实现方式。理解scatter_有助于高效地进行张量操作。
摘要由CSDN通过智能技术生成

Tensor.scatter_(dim, index, src, reduce=None)

确实是比较难以理解的一个函数。

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

感觉像是什么呢 ? 就是src 中的数,按照index中指示的位置,放在tensor中去。dim 会表示index 来表示哪个维度的坐标。

比如 代码: c.scatter(dim=0, index = index, src =src)

下面的代码 一看就懂scatter在干嘛了  。 (这里是dim为1和2的时候 )

index = torch.tensor([[0, 1, 2, 0]])
c = torch.zeros(3, 5, dtype=index.dtype)
# print(c.scatter_(0, b, src))
dim = 0
m = len(index)
n = len(index[0])
for i in range(m):
    for j in range(n):
        new_index = index[i][j]
        if dim == 0:
            c[new_index][j] = src[i][j]
        if dim == 1:
            c[i][new_index] = src[i][j]
print(c)

官网的例子:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

下图是代码运行结果。 

 加一个维度 

index = torch.tensor([[0, 1, 2, 0],
                      [1,2,0,1]])
c = torch.zeros(3, 5, dtype=index.dtype)
# print(c.scatter_(0, b, src))
dim = 0
m = len(index)
n = len(index[0])
for i in range(m):
    for j in range(n):
        new_index = index[i][j]
        if dim == 0:
            c[new_index][j] = src[i][j]
        if dim == 1:
            c[i][new_index] = src[i][j]
print(c)
print(c.scatter(0, index,src))


#
tensor([[1, 0, 8, 4, 0],
        [6, 2, 0, 9, 0],
        [0, 7, 3, 0, 0]])
tensor([[1, 0, 8, 4, 0],
        [6, 2, 0, 9, 0],
        [0, 7, 3, 0, 0]])

当用scatter函数时 也可以用value来代替src  这样所有的值都会被替换成value的值。 比如:

print(c.scatter(0, index,src))
print(c.scatter(0, index,value=1))
tensor([[1, 0, 8, 4, 0],
        [6, 2, 0, 9, 0],
        [0, 7, 3, 0, 0]])

tensor([[1, 0, 1, 1, 0],
        [1, 1, 0, 1, 0],
        [0, 1, 1, 0, 0]])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值