Pytorch的使用:torch.gather函数

torch.gather()

作用:方便从批量tensor中获取特定化维度指定索引下的数据,该索引往往是乱序的。

首先看一下官方文档中的3维数据

index 代表输入向量

dim 代表替换的维度

input 代表最终选取的元素

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

接下来我们用一个二维的数据,分别采用官方文档的思路和我个人理解的思路进行简化举例

我们采用了一个3×3的二维矩阵进行练习

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

1.index为行向量,dim = 0 替换行索引

index = torch.tensor([[1, 2, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

计算思路1:

dim = 0,所以替换行索引,即input[ index[i][j] ][j],可见整个过程就是将行替换,分别是[1,2,0],而列即为index的列,不发生变化,也为[2,1,0],即取出[(1,0),(2,1),(0,2)]。

计算思路2:

当我们熟悉了计算之后,就可以找到其中的逻辑,而不需要每次都带入计算索引。
dim = 0 代表替换行索引,而输入的是行向量,那我们先将列索引写出,即[0,1,2],然后将行索引替换为index,即[1,2,0],合并后就是最终索引[(1,0),(2,1),(0,2)]。

输出结果

tensor([[4, 8, 3]])

2.index为列向量,dim = 0 替换行索引

index = torch.tensor([[1, 2, 0]]).t()
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

计算思路1:

dim = 0,所以替换行索引,即input[ index[i][j] ][j],可见整个过程就是将行替换,分别是[1,2,0],而列即为index的列,即为[0,0,0],合并即取出[(1,0),(2,0),(0,0)]。

计算思路2:

dim = 0 代表替换行索引,而输入的是列向量,因为只有1列,所以列索引即[0,0,0],然后将行索引替换为index,即[1,2,0],合并后索引为[(1,0),(2,0),(0,0)]。

输出结果

tensor([[4],
        [7],
        [1]])

3.index为行向量,dim = 1 替换列索引

index = torch.tensor([[1, 2, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

计算思路1:

dim = 1,所以替换列索引,即input[i][ index[i][j] ],可见整个过程就是将列替换,分别是[1,2,0],而行即为index的行,即为[0,0,0],合并即取出[(0,1),(0,2),(0,0)]。

计算思路2:

dim = 0 代表替换行索引,而输入的是列向量,因为只有1列,所以行索引即[0,0,0],然后将列索引替换为index,即[1,2,0],合并后索引为[(0,1),(0,2),(0,0)]。

输出结果

tensor([[2, 3, 1]])

是不是很简单,相信你已经理解了。

参考文献

https://zhuanlan.zhihu.com/p/352877584 图解PyTorch中的torch.gather函数

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值