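import torch

# Zero-padded rows; the index [[3],[2],[4],[0]] points at the last non-zero entry of each row
data = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0],
]
# Gather along dim=1: out[i][0] = data[i][index[i][0]]
out = torch.gather(torch.tensor(data), 1, torch.tensor([[3], [2], [4], [0]]))
print(out.tolist())  # [[5], [3], [7], [1]]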
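The hard-coded index [[3],[2],[4],[0]] is exactly the position of the last non-zero value in each row. Assuming the rows are zero-padded sequences and 0 never appears inside the real data (an interpretation, not something stated above), the index can be computed from the data instead of being written by hand. This is a sketch under that assumption:

```python
import torch

# Zero-padded sequences, one per row (same data as above).
x = torch.tensor([
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0],
])

# Assumption: 0 is only used as padding, so the number of non-zero
# entries in a row equals that row's sequence length.
lengths = (x != 0).sum(dim=1)          # int64 tensor: [4, 3, 5, 1]
last_idx = (lengths - 1).unsqueeze(1)  # shape (4, 1): [[3], [2], [4], [0]]

# out[i][0] = x[i][last_idx[i][0]]
last = torch.gather(x, 1, last_idx)
print(last.squeeze(1).tolist())  # [5, 3, 7, 1]
```

If real data can contain zeros, the lengths would have to come from elsewhere (for example, stored alongside the padded batch).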
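Note: both input and index must be converted to tensors (index must be an int64 LongTensor); plain Python lists are not accepted.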
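A minimal sketch of the conversion step. It shows that passing lists directly is rejected, converts both arguments with torch.as_tensor (forcing the index to torch.long), and checks that input and index have the same number of dimensions, which torch.gather requires:

```python
import torch

rows = [[2, 3, 4, 5, 0, 0], [1, 4, 3, 0, 0, 0]]  # plain Python lists
idx = [[3], [2]]

# Passing the lists directly fails: gather() only accepts tensors.
try:
    torch.gather(rows, 1, idx)
    rejected = False
except TypeError:
    rejected = True
print("lists rejected:", rejected)  # lists rejected: True

# Convert both arguments; make the index int64 (LongTensor) explicitly.
src = torch.as_tensor(rows)
index = torch.as_tensor(idx, dtype=torch.long)

# input and index must have the same number of dimensions (both 2-D here).
assert src.dim() == index.dim()
out = torch.gather(src, 1, index)
print(out.tolist())  # [[5], [3]]
```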
Parameters
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
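To make the role of dim concrete, here is a plain-Python reference for the 2-D case. gather_2d is a hypothetical helper written only for illustration; it is not part of PyTorch. The output has the same shape as index, and dim decides which coordinate of each output position is replaced by the index value. The loop checks the helper against torch.gather for dim=0 and dim=1:

```python
import torch

def gather_2d(src, dim, index):
    """Illustrative plain-Python equivalent of torch.gather for 2-D nested lists."""
    out = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            if dim == 0:
                out_row.append(src[k][j])  # out[i][j] = src[index[i][j]][j]
            elif dim == 1:
                out_row.append(src[i][k])  # out[i][j] = src[i][index[i][j]]
            else:
                raise ValueError("dim must be 0 or 1 for 2-D input")
        out.append(out_row)
    return out

src = [[1, 2, 3], [4, 5, 6]]
cases = [(0, [[1, 0, 1]]), (1, [[2, 0], [1, 1]])]

for dim, idx in cases:
    expected = torch.gather(torch.tensor(src), dim, torch.tensor(idx)).tolist()
    assert gather_2d(src, dim, idx) == expected
    print(dim, expected)
# 0 [[4, 2, 6]]
# 1 [[3, 1], [5, 5]]
```

With dim=1 (as in the example at the top) each row is indexed independently, which is why one index per row picks one value per row.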