tensor.gather()使用

文章介绍了在PyTorch中使用gather()函数进行张量操作,包括按行或列选择特定索引值,以及如何根据需求转换索引。通过例子展示了如何获取矩阵的指定行或列数据。
摘要由CSDN通过智能技术生成

在学习pytorch的过程中遇到了该方法,该方法是按照行或按照列根据索引去取值。

举例:我们创建一个tensor

import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(x) 
''' 
我们创建了一个3*3的矩阵
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
'''

如果我们想要打印矩阵第一行直接print(x[0])即可,这很简单,如果想打印第二行第三列的值也就是上面矩阵的数字6,我们直接输出x[1][2]即可,也很简单。
如果我想在矩阵的第一列选一个数,第二列选一个数,第三列选一个数,就可以用到gather()方法了。

gather()函数主要有两个参数,第一个参数是规定按列取还是按行去,第二个参数是对于每一行要取的数据的索引:如果我想按照列取,第一列取索引为1的,第二列取索引为2的,第三列取索引为0的,代码如下:

index = torch.tensor([[1, 2, 0]]) # 定义index
dim_0 = x.gather(0, index) # dim:0表示按列取,1表示按行取
'''
结果是3*1的行向量
tensor([[4, 8, 3]])
'''

这里画个图解释一下:
在这里插入图片描述
如图,我们要按照列取,那么你的index应该是13的行向量,最后取出来的结果也是13的行向量

如果按照行取,那index应该转为3*1的一个列向量,最后结果应该也是列向量,代码如下:

index = torch.tensor([[1, 2, 0]]) # 定义索引
dim_1 = x.gather(1, index.view(3, 1)) # 按照行取所以第一个参数设为1
# 这里使用view()函数把index转为了3*1的行向量
'''
结果是1*3的列向量
tensor([[2],
        [6],
        [7]])
'''

图解:
在这里插入图片描述
参考文献 :
例解tensor.gather():https://zhuanlan.zhihu.com/p/462008911

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值