tensor.gather()的使用(pytorch)
tensor.gather(dim, indexs)
功能: 在dim维度上,按照indexs所给的坐标选择元素,返回一个和indexs维度相同大小的tensor。
它和torch.gather功能是一样的。
torch.gather()官方文档
注意: 这里indexs必须也是Tensor,并且维度数与input相同(len(input.shape)=len(indexs.shape))
例子:
>>>import torch
>>>a = torch.Tensor([[1,2,3,4,5,6],[0,1,2,3,4,5]])
>>>a
Out[4]:
tensor([[1., 2., 3., 4., 5., 6.],
[0., 1., 2., 3., 4., 5.]])
>>>
>>>b = a.gather(1, torch.tensor([[5,4,3,2,1,0],[0,1,2,3,4,5]]))
>>>b
Out[9]:
tensor([[6., 5., 4., 3., 2., 1.],
[0., 1., 2., 3., 4., 5.]])