pytorch中scatter()与scatter_()函数的用法与区别

Tensor.scatter_函数用于根据index在指定维度dim上将src的值写入self中,要求self和src的dtype相同。当index有重叠时,后面的赋值会覆盖前面的。scatter_是原地修改self的,而scatter则返回新Tensor。例如,当index有重复时,最后的值会覆盖之前的值,导致self中某些位置的值改变。
摘要由CSDN通过智能技术生成

Tensor.scatter_(dim, index, src, reduce=None) → Tensor
其作用是根据index将src中的值写到self中, dim决定了维度
这里需要注意的一点是self的dtype要和src的dtype相同!!!例如:

torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)

这里的self的dtype要和src的dtype相同。
函数的作用以3D的tensor举例子:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

举个具体的例子:

src = torch.arange(1, 11).reshape((2, 5))
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10]])
index = torch.tensor([[0, 1, 2, 0],
					  [1, 0, 1, 2]])
# tensor([[0, 1, 2, 0],
#		  [1, 0, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 7, 0, 4, 0],
#         [6, 2, 8, 0, 0],
#         [0, 0, 3, 9, 0]])

# 分析:index的i取值为0-1,j的取值从0-3都可以
# self[index[0][0]][0] = self[0][0] = src[0][0] = 1
# self[index[0][1]][1] = self[1][1] = src[0][1] = 2
# self[index[0][2]][2] = self[2][2] = src[0][2] = 3
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# self[index[1][0]][0] = self[1][0] = src[1][0] = 6
# self[index[1][1]][1] = self[0][1] = src[1][1] = 7
# self[index[1][2]][2] = self[1][2] = src[1][2] = 8
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9

这里还有个有意思的事情, 上面的情况是没有重叠的情况,假设index的上下两行中有重叠的元素,比如

index = torch.tensor([[0, 1, 2, 0],
					  [1, 0, 1, 0]])

注意第一行的最后一个元素与第二行的最后一个元素相同了, 都为0。(之前第二行最后一个元素为2)
这样的话上面的取值

# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9
变为了
# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[0][3] = src[1][3] = 9

可以看到self[0][3]有了2个赋值,一次是根据i=0,j=3所赋的4;另一次是根据i=1,j=3所赋的9;根据前后顺序关系,9会把4个给覆盖掉,因此最终得到的结果变为:

tensor([[1, 7, 0, 9, 0],
        [6, 2, 8, 0, 0],
        [0, 0, 3, 0, 0]])

scatter()与scatter_()的区别在于scatter_()是原地操作的。
举例,b = a.scatter(dim, index, src)后a的值不会发生变化
相对的, b = a.scatter_(dim, index, src)后a的值发生变化, 变得与b相等

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值