torch.gather()
的官方解释是: 通过dim
沿着坐标轴聚集值。
上面那句话太简短了,读完了也不知道这个函数到底怎么用。先看一下它的全部样子:
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
input (Tensor)
– 源张量(source tensor)dim (python:int)
– 用来寻址的坐标轴index (LongTensor)
– 需要被聚集的元素的索引out (Tensor, optional)
–最终输出的tensor,这个一般不会用到sparse_grad (bool,optional)
– 如果为真, 关于输入的梯度将是一个稀疏张量。这个一般也用不到。
现在我们来看看具体这么用,看下面这个例子就一目了然了。
- dim =1
import torch
x = torch.tensor([[1