pytorch 深入理解 tensor.scatter_ ()用法

pytorch 深入理解 tensor.scatter_ ()用法

在 pytorch 库下理解 torch.tensor.scatter()的用法。作者在网上搜索了很多方法,最后还是觉得自己写一篇更为详细的比较好,转载请注明。
首先,scatter() 和 scatter_() 的作用是一样的,但是 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会修改原先的 Tensor。

1 API格式

torch.Tensor.scatter_(dim, index, src) → Tensor

字面意思:对一个 torch.Tensor 进行操作,dim,index,src三个为输入的参数。

  • dim 就是在哪个维度进行操作,注意,dim 的不同,在其他条件相同的条件下得到的output 也不同。
  • index 是输入的索引。
  • src 就是输入的向量,也就是 input。

最后,函数返回一个 Tensor。

2 具体示例

import torch as th
# import torch 包

a = th.rand(2,5)	
# 初始化向量 a,size 为 (2, 5),二维向量,2行5列,每个元素是 0 到 1 的均匀分布采样
# 把 a 作为 src,也就是 input 
# a 的初始化数值如下: 
src tensor:
tensor([[0.6789, 0.7350, 0.6104, 0.7777, 0.9613],
        [0.1432, 0.8788, 0.3269, 0.0063, 0.6070]])

        
# 初始化 b 为size 为 (3, 5) 的向量,二维向量,3行5列,每个元素被初始化为 0
b = th.zeros(3, 5).scatter_(
	dim = 0,
	index = th.LongTensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]]),
	src = a 
)
# dim = 0, out:
tensor([[0.6789, 0.8788, 0.3269, 0.7777, 0.9613],
        [0.0000, 0.7350, 0.0000, 0.0063, 0.0000],
        [0.1432, 0.0000, 0.6104, 0.0000, 0.6070]])
        
# 初始化 c 为size 为 (3, 5) 的向量,二维向量,3行5列,每个元素被初始化为 0
c = th.zeros(3, 5).scatter_(
	dim = 1,
	th.LongTensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]]),
	src = a	
)
# dim = 1, out:
tensor([[0.9613, 0.7350, 0.6104, 0.0000, 0.0000],
        [0.3269, 0.0063, 0.6070, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

下面来解释一下,b,c 内的元素分别是怎么得到的。

2.1 dim = 0 下的结果分析

先说 b,也就是 dim =0 下得到的结果。我们来看下官方给的说明文字:

self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

因为这时 dim = 0,且只有 2 个维度,所以我们只用看第一行就行。

self [index[i][j]] [j] = src[i][j] # if dim == 0

仅用这一个公式就确定了 b 中所有元素的取值,与 a 的映射关系。这里等号左边的 self 可看做 output,也就是 b;src 是我们的输入向量,也就是 a。这里的 i,j 分别是输入向量 src 的 size 的取值。比如,本例中 a 的 size 为 (2,5),也就是说,对于 a 中的元素,i 的取值为 0,1;j 的取值为 0,1,2,3,4。a 中的元素的索引也就是(0,0),(0,1),… (0,4);(1,0),(1,1),…(1,4) 完毕,一共 2*5 = 10 个元素。
了解了这些以后,通过举例来说明 b 中的元素都是如何确定的。

index = th.LongTensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]]),
我们列举一些元素来说明其映射关系

当 i = 0,j = 0 时,
我们用类似上述确定 a 索引的方式确定了 index[i][j] = 0,
这里的 0 就是 [0,1,2,0,0] 中最左边的 0,
则 b = out[index[i][j]][j] = out[0][0] = src[0][0] = 0.6789

当 i = 0,j = 1 时,index[0][1] = 1,
这里的 1 就是 [0,1,2,0,0] 中的 1,
同理,b = out[index[i][j]][j] = out[1][1] = src[0][1] = 0.7350

当 i = 0,j = 2 时,index[0][2] = 2,
这里的 2 就是 [0,1,2,0,0] 中的 2,
同理,b = out[index[i][j]][j] = out[2][2] = src[0][2] = 0.6104
注意,这里的out[2][2] 不是第 2 行,第 2 列的元素,是第 3 行,第 3 列的元素

当 i = 1,j = 1 时,index[1][1] = 0,
这里的 0 就是 [2,0,0,1,2] 中最**左**边的 0,
同理,b = out[index[i][j]][j] = out[0][1] = src[1][1] = 0.8788

当 i = 1,j = 3 时,index[1][3] = 0,
这里的 0 就是 [2,0,0,1,2] 中最**右**边的 0,
同理,b = out[index[i][j]][j] = out[0][1] = src[1][3] = 0.0063

当 i = 1,j = 4 时,index[1][4] = 2,
这里的 2 就是 [2,0,0,1,2] 中最**左**边的 0,
同理,b = out[index[i][j]][j] = out[0][1] = src[1][4] = 0.6070

由此得到了 b 中有映射关系的元素,剩余的元素,由于 b 被初始化为全 0 向量,所以剩余的元素均为 0 。

dim = 1的时候,同理。只是换了一种映射机制,如法炮制。

有任何关于内容不够详细,解释不清,错误等欢迎留言。转载请注明,支持原创,谢谢。

  • 25
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
`torch_scatter.scatter_max`函数是PyTorch中的一种scatter函数,用于将输入的Tensor按照指定的维度进行散射操作,并返回指定维度上的元素最大值和对应的索引位置。 该函数的输入包括三个参数:输入Tensor(即要进行散射操作的Tensor)、散射维度dim和索引Tensor(即指定维度上的索引位置)。输出包括两个Tensor:散射后的Tensor和对应的最大值和索引位置。 具体来说,`torch_scatter.scatter_max`函数的操作流程如下: 1. 根据索引Tensor将输入Tensor按照指定维度进行散射操作,得到一个散射后的Tensor。 2. 在指定维度上找到散射后的Tensor中的最大值和对应的索引位置。 3. 返回散射后的Tensor和最大值和索引位置对应的两个Tensor。 值得注意的是,如果输入Tensor中某些元素在指定维度上对应的索引位置相同,那么在散射操作时,这些元素的最大值和索引位置会被更新为最后一个被处理到的元素的最大值和索引位置。 下面是一个简单的示例代码,演示了如何使用`torch_scatter.scatter_max`函数: ```python import torch from torch_scatter import scatter_max # 定义一个输入Tensor x = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) # 定义一个索引Tensor index = torch.tensor([0, 1, 0]) # 在第一维上进行散射操作,得到散射后的Tensor和最大值和索引位置对应的两个Tensor out, argmax = scatter_max(x, index, dim=0) # 输出结果 print(out) # tensor([[0.7000, 0.8000, 0.9000], [0.4000, 0.5000, 0.6000]]) print(argmax) # tensor([2, 1]) ``` 在上面的示例代码中,我们首先定义了一个3x3的输入Tensor `x`,然后定义了一个长度为3的索引Tensor `index`,表示在第一维上,第一个元素要被散射到第0个位置,第二个元素要被散射到第1个位置,第三个元素要被散射到第0个位置。 之后我们调用`torch_scatter.scatter_max`函数,在第一维上进行散射操作,得到了散射后的Tensor `out`和最大值和索引位置对应的两个Tensor `argmax`。最后我们输出了这两个Tensor的值,可以看到在第一维上,第一个位置对应的最大值为0.7,索引为2,第二个位置对应的最大值为0.5,索引为1。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值