首先给出pytorch官方定义:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
我们只需要关注input、dim和index三个参数即可(input即被index索引的原始tensor,dim即index中的元素在input的下标中占那个位置,例如有索引a[i][j],当dim=0时,index中的元素占第一个位置,即i的位置。index当然就是input的索引啰。)
然后贴我自己摸索的代码,能看懂的请直接划走!
import torch
a = torch.arange(15).view(3, 5)
print("a:",a,a.shape,a[2][0])
"""
a: tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]]) torch.Size([3, 5])
"""
b = torch.zeros_like(a)
# print("b:",b)
b[1][2] = 1
b[0][0] = 2
print("b:",b)
"""
b: tensor([[2, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]])
"""
c = a.gather(0, b) # dim=0
print("c:",c)
"""
c: tensor([[10, 1, 2, 3, 4],
[ 0, 1, 7, 3, 4],
[ 0, 1, 2, 3, 4]])
dim = 0时,
c = {
[a[2][0],a[0][1],a[0][2],a[0][3],a[0][4]],
[a[0][0],a[0][1],a[1][2],a[0][3],a[0][4]],
[a[0][0],a[0][1],a[0][2],a[0][3],a[0][4]]
}
"""
d = a.gather(1, b) # dim=0
print("d:",d)
"""
d: tensor([[ 2, 0, 0, 0, 0],
[ 5, 5, 6, 5, 5],
[10, 10, 10, 10, 10]])
"""
别装了,你看不懂的,快看下面的内容:
c = a.gather(0, b) # dim=0
print("c:",c)
"""
c: tensor([[10, 1, 2, 3, 4],
[ 0, 1, 7, 3, 4],
[ 0, 1, 2, 3, 4]])
dim = 0时,
c = {
[a[2][0],a[0][1],a[0][2],a[0][3],a[0][4]],
[a[0][0],a[0][1],a[1][2],a[0][3],a[0][4]],
[a[0][0],a[0][1],a[0][2],a[0][3],a[0][4]]
}
"""
按照我的理解(不一定对哈),dim=0时,张量b中的元素填入a的对应位置的第0维索引,如下图:
上图中橙色箭头指向的即是索引与输入的对应关系,剩余的请自己摸索。