官方文档地址:https://pytorch.org/docs/stable/generated/torch.gather.html
这里只讲一下我对input, dim, index这三个参数的理解。
input就是信息源。
dim和index得一起讲。这个函数之所以这么难理解,感觉就在index这个英语单词上。
作为名词的index在中文中叫索引,对于官方文档这个3维tensor的例子,i, j, k就是3维tensor在0, 1, 2三个维度上的索引,通过索引取得element的值。
作为动词的index感觉比较难以描述。python的list有就有index这个成员函数,[3,4,5].index(4)会得到返回值1,index作为动词使用,如index something,我觉得可以理解为给something找到它的index。
已知返回值与index参数有着相同的shape。回到dim这个参数,the axis along which to index,说的就是这个维度(其实也只有这个维度)的索引是需要被找到的,而要找的就是index[i][j][k]这个element,其他维度的索引都相同。对于index参数的每一个element,通过其每个维度每个索引,除了dim维度的索引换做element值以外,其余都不变,在input中找到对应的element,最后得到返回的tensor。
哎,确实不好讲清楚,感觉写了一堆废话。