torch.gather() 和torch.sactter_()的用法简析

 

torch.gather(input, dim, index, out=None)  和 torch.scatter_(dim, index, src)是一对作用相反的方法

 

先来看torch.gather, 核心操作其实就是这样:

out[i][j][k] = input[index[i][j][k]] [j][k]  # if dim == 0

out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1

out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

是对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index

官方文档给的例子是:

>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
 1  1
 4  3
[torch.FloatTensor of size 2x2]

具体过程就是这里的input = [[1,2],[3,4]], index = [[0,0],[1,0]], dim = 1, 则

out[0][0] = input[0][ index[0][0] ] = input[0][0] = 1

out[0][1] = input[0][ index[0][1] ] = input[0][0] = 1

out[1][0] = input[1][ index[1][0] ] = input[1][1] = 4

out[1][1] = input[1][ index[1][1] ] = input[1][0] = 3

 

 

torch.scatter_(dim, index, src)

核心操作:

self[ index[i][j][k] ][ j ][ k ] = src[i][j][k]  # if dim == 0

self[ i ][ index[i][j][k] ][ k ] = src[i][j][k]  # if dim == 1

self[ i ][ j ][ index[i][j][k] ] = src[i][j][k]  # if dim == 2

这个就是对于src(或者说input)指定位置上的值,去分配给output对应索引位置,根据是index,所以其实把src放在左边更容易理解,官方给的例子如下:

x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

此例中,src就是x,index就是[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]],  dim=0

我们把src写在左边,把self写在右边,这样好理解一些,

但要注意是把src的值赋给self,所以用箭头指过去:

0.4319 = Src[0][0] ----->self[ index[0][0] ][0] ----> self[0][0]

0.6500 = Src[0][1] ----->self[ index[0][1] ][1] ----> self[1][1]

0.4080 = Src[0][2] ----->self[ index[0][2] ][2] ----> self[2][2]

0.8760 = Src[0][3] ----->self[ index[0][3] ][3] ----> self[0][3]

0.2355 = Src[0][4] ----->self[ index[0][4] ][4] ----> self[0][4]

 

0.2609 = Src[1][0] ----->self[ index[1][0] ][0] ----> self[2][0]

0.4711 = Src[1][1] ----->self[ index[1][1] ][1] ----> self[0][1]

0.8486 = Src[1][2] ----->self[ index[1][2] ][2] ----> self[0][2]

0.8573 = Src[1][3] ----->self[ index[1][3] ][3] ----> self[1][3]

0.1029 = Src[1][4] ----->self[ index[1][4] ][4] ----> self[2][4]

 

则我们把src也就是 x的每个值都成功的分配了出去,然后我们再把self对应位置填好

剩下的未得到分配的位置,就填0补充。

### PyTorch 中 `torch.gather` `repeat` 的搭配用法 #### 什么是 `torch.gather`? `torch.gather` 是 PyTorch 中的一个操作,用于从输入张量的不同维度收集指定索引位置上的元素。它允许按照给定的索引来提取特定的数据片段。 其基本语法为: ```python torch.gather(input, dim, index, out=None, sparse_grad=False) ``` 其中: - `input`: 输入张量。 - `dim`: 要沿着哪个维度进行聚集操作。 - `index`: 指定要采集的索引值,形状需与目标一致[^1]。 #### 什么是 `repeat` 或 `.expand()` 方法? `.repeat()` 是一种扩展张量的方法,可以重复某个张量的内容以匹配新的形状。它的作用类似于 NumPy 的 `tile` 函数。对于简单的广播机制不足的情况,可以通过 `.repeat()` 来实现更复杂的形状调整。 其基本语法为: ```python tensor.repeat(*sizes) ``` 参数 `*sizes` 表示各个维度上需要重复的次数[^2]。 --- #### 组合使用场景分析 当我们将 `torch.gather` `repeat` 结合起来时,通常是为了处理以下情况之一: 1. **动态索引选择并复制**:先通过 `gather` 收集某些特定数据,再利用 `repeat` 将这些数据沿某一维度扩展到更大的规模。 2. **批量操作中的灵活映射**:在批处理模式下,针对不同样本执行个性化的特征抽取或变换。 下面是一个具体的例子展示如何组合这两个功能。 --- #### 使用案例 假设有一个二维矩阵 A,以及一组对应每一行的最大值索引 B。我们需要根据索引 B 提取每行最大值,并将其扩展成一个新的三维张量 C。 ##### 示例代码 ```python import torch # 创建一个随机二维张量 A (batch_size=3, feature_dim=4) A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float) # 假设我们知道每行的最大值对应的列索引 B = torch.tensor([3, 2, 1]) # shape: (batch_size,) -> [3, 2, 1] # Step 1: 使用 gather 获取每行的最大值 max_values_per_row = torch.gather(A, 1, B.unsqueeze(1)) # unsqueeze 使 B 变为 (batch_size, 1) print("Max values per row:", max_values_per_row.squeeze()) # Output: tensor([4., 7., 10.]) # Step 2: 扩展结果至更高维空间 C = max_values_per_row.expand(-1, 4).unsqueeze(2) # expand 到 (batch_size, 4), 并增加一维变为 (batch_size, 4, 1) D = C.repeat(1, 1, 3) # 进一步 repeat 成 (batch_size, 4, 3) print(D.shape) # 输出应为 torch.Size([3, 4, 3]) ``` 上述代码展示了如何结合 `torch.gather` `repeat` 完成复杂的数据转换任务。这里的关键在于理解 `gather` 如何定位所需数据,以及 `repeat` 怎样改变张量结构[^3]。 --- #### 注意事项 1. 当调用 `gather` 时,确保索引张量 (`index`) 的尺寸与目标维度相兼容。 2. 对于高阶张量的操作,务必注意各维度的意义及其顺序关系。 3. 如果涉及 GPU 计算,请确认所有参与运算的对象均位于同一设备之上[^4]。 ---
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值