torch.gather()
The PyTorch API documentation describes it as follows:
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim. For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
If input is an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, x_i, x_{i+1}, \ldots, x_{n-1})$ and dim = i, then index must be an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, y, x_{i+1}, \ldots, x_{n-1})$ where $y \geq 1$, and out will have the same size as index.
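The size rule can be illustrated with a short sketch. The shapes (2, 3, 4) and y = 5 are arbitrary choices made for this illustration; the index values are random, so only the shapes and the lookup rule are checked, not specific outputs.

```python
import torch

# Input of size (2, 3, 4); gather along dim = 1.
inp = torch.arange(24).reshape(2, 3, 4)

# index matches the input in every dimension except dim = 1, where y = 5.
# Its entries must lie in [0, 3) because inp.size(1) == 3.
index = torch.randint(0, 3, (2, 5, 4))

out = torch.gather(inp, 1, index)

# The output takes the size of index, not the size of the input.
print(out.shape)  # torch.Size([2, 5, 4])

# Spot-check one element against the dim == 1 rule.
print(out[1, 4, 2] == inp[1, index[1, 4, 2], 2])  # tensor(True)
```

Note that y may exceed the input's size along dim: the same input position can be gathered more than once.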
Example
Take a 2-D tensor as an example:
import torch
input = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[0, 1], [2, 0]])  # integer (int64) indices
torch.gather(input, dim=1, index=index)
Here dim = 1 and the input has size (2, 3), so index must have size (2, y) with $y \geq 1$; in this example y = 2.
Following the rule for dim = 1, the function performs these operations:
out[0][0] = input[0][index[0][0]] = 1
out[0][1] = input[0][index[0][1]] = 2
out[1][0] = input[1][index[1][0]] = 6
out[1][1] = input[1][index[1][1]] = 4
Therefore, the output is:
>>> [[1, 2], [6, 4]]
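For contrast, here is a sketch of gathering along dim = 0 with the same 2 x 3 input values. The index tensor below is an assumption chosen for illustration; with dim = 0 it must have size (y, 3), and its entries select a row for each column.

```python
import torch

inp = torch.tensor([[1, 2, 3], [4, 5, 6]])

# dim = 0: index has size (y, 3) with y = 2, and every entry is a row number (0 or 1).
index = torch.tensor([[0, 1, 0], [1, 0, 1]])

out = torch.gather(inp, 0, index)

# out[i][j] = inp[index[i][j]][j]
# out[0] = [inp[0][0], inp[1][1], inp[0][2]] = [1, 5, 3]
# out[1] = [inp[1][0], inp[0][1], inp[1][2]] = [4, 2, 6]
print(out.tolist())  # [[1, 5, 3], [4, 2, 6]]
```

Compared with the dim = 1 case, the column position is now fixed by the output location and the index chooses which row to read from.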