函数解析
torch.gather(input,dim,index)
先从二维开始:
>>> a = torch.arange(9).reshape(3,3)
>>> a
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> index= torch.LongTensor([[2,1,0]]) # 这里index和input必须有一样的维度...
>>> torch.gather(a,0,index)
tensor([[6, 4, 2]])
这里为了简单把index只做了13,如果是23,就类似:
这里dim的意义是在哪个维度上去取值,这意味着:对index和input来说,除了这个维度以外,其他的维度必须一样。
高维的gather
如果维度成为三维或者更高维,也是类似的。