在学习pytorch的过程中遇到了该方法,该方法是按照行或按照列根据索引去取值。
举例:我们创建一个tensor
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(x)
'''
我们创建了一个3*3的矩阵
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
'''
如果我们想要打印矩阵第一行直接print(x[0])即可,这很简单,如果想打印第二行第三列的值也就是上面矩阵的数字6,我们直接输出x[1][2]即可,也很简单。
如果我想在矩阵的第一列选一个数,第二列选一个数,第三列选一个数,就可以用到gather()方法了。
gather()函数主要有两个参数,第一个参数是规定按列取还是按行去,第二个参数是对于每一行要取的数据的索引:如果我想按照列取,第一列取索引为1的,第二列取索引为2的,第三列取索引为0的,代码如下:
index = torch.tensor([[1, 2, 0]]) # 定义index
dim_0 = x.gather(0, index) # dim:0表示按列取,1表示按行取
'''
结果是3*1的行向量
tensor([[4, 8, 3]])
'''
这里画个图解释一下:
如图,我们要按照列取,那么你的index应该是13的行向量,最后取出来的结果也是13的行向量
如果按照行取,那index应该转为3*1的一个列向量,最后结果应该也是列向量,代码如下:
index = torch.tensor([[1, 2, 0]]) # 定义索引
dim_1 = x.gather(1, index.view(3, 1)) # 按照行取所以第一个参数设为1
# 这里使用view()函数把index转为了3*1的行向量
'''
结果是1*3的列向量
tensor([[2],
[6],
[7]])
'''
图解:
参考文献 :
例解tensor.gather():https://zhuanlan.zhihu.com/p/462008911