作用:
收集输入的特定维度指定位置的数值
函数定义:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
参数:
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
理解例子:
一个二维矩阵:
input_tensor=[
[2, 3, 4, 5, 0, 0],
[1, 4, 3, 0, 0, 0],
[4, 2, 2, 5, 7, 0],
[1, 0, 0, 0, 0, 0]
]
现在有需求,如何取每行的最后一个非零元素呢? (如果说用循环的同学,可以离开了)
函数使用:
torch.gather(input= input_tensor , dim = 1,index=[[4],[3],[5],[1]])
- dim=1是因为每行的最后一个非零元素是一个列!
- index之所以维度是4*1,是为了满足index维度和output维度之间相等的关系。index的元素可以发现,是每行所要取的数的下标索引(从1开始索引而不是0)
- 结果就是[[5,3,7,1]]