官方文档对该函数的解释:
按自己的理解翻译的,如有错误望指出
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
作用:沿 dim 指定的轴收集值
参数:
input (Tensor) – 要操作的张量
dim (int) – 要索引的轴
index (LongTensor) – 要收集的元素的索引
out (Tensor, optional) – 目标张量-要收集数据得到的张量
sparse_grad (bool,optional) – 如果为真,梯度 w.r.t 输入为稀疏张量
例子:
t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])