pytorch: gather函数,index_fill函数
torch.gather(input, dim, index, out=None) → Tensor
In [28]: a=torch.arange(0,16).view(4,4)
In [29]: a
Out[29]:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
In [30]: index=torch.LongTensor([[0,1,2,3]])
In [31]: torch.gather(a, 0, index)
Out[31]: tensor([[ 0, 5, 10, 15]])
In [44]: index=torch.LongTensor([[0],[1],[1],[2]])
In [45]: a.gather(1,index)
Out[45]:
tensor([[ 0],
[ 5],
[ 9],
[14]])
**函数的作用:沿给定轴dim,将输入索引张量index指定位置的值进行聚合。这里强调下,index.dim()==input.dim()必须相等,**比如2D,3D,这样index最后一维的值直接就索引到input指定轴里的数据了。通俗的讲就是你给我指定好位置,我去对应位置那数据。
index_fill(dim, index, val)
In [28]: a=torch.arange(0,16).view(4,4)
In [29]: a
Out[29]:
tensor([[ 0