torch.gather(源张量,维度轴dim,索引张量)
gather()是类似于数组按下标获取元素值的方法,只不过数组或者二维数组可以直接通过行列下标获取值, 而张量一般都是多维度的,不可以用下标获取,需要借助维度轴(dim)。
源张量指的就是原始张量,维度轴如图
索引张量的样式要与目标张量保持一致。
例:
t = torch.Tensor([[1,2],[3,4]])#创建一个2*2的浮点张量,这是一个二维的张量,存在dim=0和dim=1,当dim=0时,即纵轴,此时索引张量里面的数字即为行下标(从0开始),
该数字所在索引张量的列即源张量的列。当dim=1时,即横轴,此时索引张量里面的数字即为列下标(从0开始),该数字所在索引张量的行即源张量的行。
1 | 2 |
3 | 4 |
下面用两个索引张量例证
torch.tensor([[0,0]])#索引张量,假如dim=1。第一个0有两个意思,0本身是代表列下标为0,此0所在索引张量的第一行意味着代表源张量的第一行也就是源张量行下标为0,所以指向的数字就是源张量的1;同理,第二个0也有两个意思,0本身代表源张量列下标为0,此0所在索引张量的第一行意味着代表源张量的第一行也就是源张量行下标为0,所以指向的数字就是源张量的1.与此同时,索引张量的形式为1*2,那么最后得出的目标张量也是1*2.
print(torch.gather(t,1,torch.tensor([[0,0]])))
>>> tensor([[1., 1.]])#这是输出结果。
torch.tensor([[1,0],[1,1]])#假如dim=1.第一个1代表源张量列下标为1,所在索引张量第一行,也就是源张量行下标为0,最后所指向的数字为2。第二个1代表源张量列下标为1,所在索引张量第二行,也就是源张量行下标为1,最后
所指向的数字就是源张量的4.其他同理
>>> tensor([[2., 1.], [4., 4.]]#此为输出。