torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor函数
Gathers values along an axis specified by dim.
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
out (Tensor, optional) – the destination tensor
注意这里index的类型是LongTensor, 然后这个函数看了半天.
搞明白索引谁很重要
对于二维张量(m, n), 如果dim=1, 那么输出的tensor shape就是(m, X), index的shape就必须是(m, X), 其中X≤n. index具体每一个值≤m.
b = torch.rand(2,3)
print(b ,'\n')
index1 = torch.LongTensor([[0,1,1], [2,0,1]]) # 每一个具体值≤3
print(torch.gather(b, dim=1, index=index1))
tensor([[0.2017, 0.2936, 0.3220],
[0.7503, 0.6031, 0.2519]])
tensor([[0.2017, 0.2936, 0.2936],
[0.2519, 0.7503, 0.6031]])
如果dim=0, 那么输出的tensor就是 (X, n); index的shape必须是(X, n), 其中 X≤m, index具体每一个值≤n.
index2 = torch.LongTensor([[0,1,0]]) # 每一个具体值≤2
print(torch.gather(b, dim=0, index=index2))
tensor([[0.7193, 0.9087, 0.6696]])