标题理解PyTorch中的torch.gather函数
按自己理解记录自己对gather函数的认识,以便回顾
图来源于知乎评论:https://zhuanlan.zhihu.com/p/352877584
1.实例
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
#输出
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
2. index为一维向量
2.1 输入行向量index,并替换行索引(dim=0)
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
#输出
tensor([[9, 7, 5]])
dim=0,则确定列数值,第一列为0,第二列为1,以此类推
再将index填入,因为只有一行,故只填第一行,结果为
2.2 输入列向量index,并替换列索引(dim=1)
index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
#输出
tensor([[5],
[7],
[9]])
dim=1,则确定行数值,第一行为0,第二行为1,以此类推
再将index填入,因为只有一列,故只填第一列,结果为
准确来说应该是:
(0,2)
(1,1)
(2,0)