torch.gather(input, dim, index) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
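The loop-based sketch below spells out the three formulas element by element and checks the result against `torch.gather` on random data. The helper name `gather_3d_reference` is hypothetical and only for illustration. It handles 3-D tensors only and is far slower than `torch.gather`.

```python
import torch

def gather_3d_reference(input, dim, index):
    # Hypothetical helper: a direct transcription of the dim == 0/1/2 formulas.
    if dim not in (0, 1, 2):
        raise ValueError("dim must be 0, 1 or 2 for a 3-D tensor")
    out = torch.empty(index.shape, dtype=input.dtype)  # out has the shape of index
    I, J, K = index.shape
    for i in range(I):
        for j in range(J):
            for k in range(K):
                idx = index[i, j, k].item()
                if dim == 0:
                    out[i, j, k] = input[idx, j, k]
                elif dim == 1:
                    out[i, j, k] = input[i, idx, k]
                else:
                    out[i, j, k] = input[i, j, idx]
    return out

torch.manual_seed(0)
inp = torch.randn(2, 3, 4)
for d in range(3):
    # Random indices that are valid positions along dimension d
    idx = torch.randint(0, inp.size(d), (2, 3, 4))
    assert torch.equal(gather_3d_reference(inp, d, idx), torch.gather(inp, d, idx))
```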
input and index must have the same number of dimensions. It is also required that index.size(d) <= input.size(d) for all dimensions d != dim. out will have the same shape as index. Note that input and index do not broadcast against each other.
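Here is a small 2-D sketch of these shape rules, assuming a recent PyTorch release. The output takes the shape of index. The bound on index.size(d) applies only to dimensions other than dim: along dim itself, index may be longer than input, as long as every value is a valid position. Exceeding the bound in another dimension raises a `RuntimeError`. The tensor names are illustrative.

```python
import torch

src = torch.arange(12).reshape(3, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

# Gather along dim=1: index has 2 rows (<= 3) and 5 columns.
# The column count is not bounded by src.size(1); only the values must lie in [0, 4).
idx = torch.tensor([[0, 3, 3, 1, 0],
                    [2, 2, 0, 1, 3]])
out = torch.gather(src, 1, idx)
print(out.shape)  # torch.Size([2, 5]), the same as idx.shape
print(out)
# tensor([[0, 3, 3, 1, 0],
#         [6, 6, 4, 5, 7]])

# 4 rows > src.size(0) == 3 in a dimension other than dim, so gather rejects it.
bad = torch.zeros(4, 2, dtype=torch.long)
raised = False
try:
    torch.gather(src, 1, bad)
except RuntimeError as e:
    raised = True
    print("RuntimeError:", e)
```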
Parameters:
- input (Tensor): the source tensor
- dim (int): the axis along which to index
- index (LongTensor): the indices of elements to gather
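The dim parameter determines which coordinate of each output position is replaced by the value taken from index. The following 2-D sketch compares dim=1 and dim=0 on the same input. The variable names are illustrative.

```python
import torch

t = torch.tensor([[1, 2],
                  [3, 4]])

# dim=1: out[i][j] = t[i][index[i][j]]  (choose a column within each row)
by_col = torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
print(by_col)
# tensor([[1, 1],
#         [4, 3]])

# dim=0: out[i][j] = t[index[i][j]][j]  (choose a row within each column)
by_row = torch.gather(t, 0, torch.tensor([[0, 1], [1, 0]]))
print(by_row)
# tensor([[1, 4],
#         [3, 2]])

# The method form Tensor.gather(dim, index) performs the same operation.
assert torch.equal(t.gather(1, torch.tensor([[0, 0], [1, 0]])), by_col)
```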
Example:
import torch

# x is the output of an RNN and output_dims is the output size at each time step;
# for every step we want to pick the prediction at the position given by index.
# [batch_size, seq_len, output_dims]
x = torch.tensor([[[0.1, 0.7, 0.1, 0.1], [0.2, 0.6, 0.1, 0.1], [0.1, 0.5, 0.2, 0.2]],
                  [[0.2, 0.4, 0.2, 0.2], [0.3, 0.3, 0.2, 0.2], [0.1, 0.4, 0.2, 0.3]],
                  [[0.1, 0.6, 0.1, 0.2], [0.5, 0.3, 0.1, 0.1], [0.1, 0.1, 0.2, 0.6]]])
# [batch_size, seq_len]
index = torch.tensor([[0, 1, 2], [2, 1, 1], [0, 0, 0]])
index = index.unsqueeze(-1)  # add a trailing dimension: [batch_size, seq_len, 1]
pred = torch.gather(x, dim=2, index=index)
# Result pred, shape [batch_size, seq_len, 1]:
# tensor([[[0.1], [0.6], [0.2]],
#         [[0.2], [0.3], [0.4]],
#         [[0.1], [0.5], [0.1]]])
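The gathered result keeps a trailing dimension of size 1, so it is common to squeeze it away and get one value per (batch, step) pair. The sketch below does that. For comparison, it also performs the same selection with advanced indexing, broadcasting a batch index and a step index against index. The names picked, picked_adv, b and s are illustrative.

```python
import torch

# Same x and index as in the example above
x = torch.tensor([[[0.1, 0.7, 0.1, 0.1], [0.2, 0.6, 0.1, 0.1], [0.1, 0.5, 0.2, 0.2]],
                  [[0.2, 0.4, 0.2, 0.2], [0.3, 0.3, 0.2, 0.2], [0.1, 0.4, 0.2, 0.3]],
                  [[0.1, 0.6, 0.1, 0.2], [0.5, 0.3, 0.1, 0.1], [0.1, 0.1, 0.2, 0.6]]])
index = torch.tensor([[0, 1, 2], [2, 1, 1], [0, 0, 0]])  # [batch_size, seq_len]

# gather, then drop the trailing size-1 dimension -> [batch_size, seq_len]
picked = torch.gather(x, 2, index.unsqueeze(-1)).squeeze(-1)

# Same selection via advanced indexing: b and s broadcast against index
batch_size, seq_len = index.shape
b = torch.arange(batch_size).unsqueeze(1)  # [batch_size, 1]
s = torch.arange(seq_len).unsqueeze(0)     # [1, seq_len]
picked_adv = x[b, s, index]                # [batch_size, seq_len]

assert torch.equal(picked, picked_adv)
print(picked.shape)  # torch.Size([3, 3])
```

Choosing between the two is mostly a matter of taste. gather reads naturally when index already has the right number of dimensions, while advanced indexing avoids the unsqueeze/squeeze pair.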