torch.gather()函数可以理解为根据索引和维度来求张量中对应的数,最后得到的是一个shape和index相同的张量
即“以A = B.gather(dim=0, index=torch.tensor([[2, 1, 2]]))为例,首先确定A的维度与index维度一致(index维度可以是任意的维度,不要受限于B),即A的维度为(1,3);其次dim=0代表按列索引,那么index第一个元素“2”的含义为在B中其所在列(即第0列)的第2个元素。同理,index第二个元素“1”的含义为在B中其所在列(即第1列)的第1个元素;index第三个元素“2”的含义为在B中其所在列(即第2列)的第2个元素。
如果是A = B.gather(dim=1, index=torch.tensor([[2, 1, 2]])),那么按行索引,而[[2, 1, 2]]本身就是行向量,故“2”“1”“2”都代表的是B中第0行的对应列数的元素,将它们拿出来,即组成A”
例如
tensor是一个3行3列的张量
[ 3, 4, 5,
6, 7, 8,
9, 10, 11]
行索引和列索引![](https://i-blog.csdnimg.cn/blog_migrate/8a639bfbccbd5c9a92b3ec33f20d7441.png)
二维索引