这个函数不好理解,但是在知乎上看到一个言简意赅的解释,特地记录。
例子:
input= torch.arange(3, 12).view(3, 3)
# tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
output= input.gather(dim=1, index)
解释:
- 首先明确一点,index和output的shape是一样的,即index.shape = output.shape。然后可以确定output = [ [ ?, ?, ?] ]
- 接着把output写成索引形式,即output = [ [ (0,1), (0,2), (0,3) ] ]
- 然后,由于gather中的dim参数是1,我们就把索引形式的output在1这个维度上替换成index的元素,即output 从 [ [ (0,1), (0,2), (0,3) ] ] 变成了 [ [ (0,2), (0,1), (0,0) ] ]
- 然后按照索引取元素,得到output =[ [ 5, 4, 3] ]