功能:根据索引来对高维tensor进行选择
要求:
- input tensor 与 index 的 dim一致
- index.shape < input.shape
torch.gather(input, dim, index) → Tensor
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
import torch
# [[1,2,3],
# [4,5,6],
# [7,8,9]]
input = torch.range(1, 9).view(3, 3)
示例1
#====================== dim=0 索引======================#
# 表示提供采样索引维度的是dim=0
dim = 0
# index的维度为 [1, 3], 说明输出维度也为 [1, 3], index每个元素的索引为 [[(0, 0), (0, 1), (0, 2)]]
index = torch.tensor([[2, 1, 0]])
# 用index 来替换index的索引列表中的dim=0 得到: [[(2, 0), (1, 1), (0, 2)]]
output = torch.gather(input, dim, index)
# 将input索引为 (2, 0), (1, 1), (0, 3) 取出来就是 [[7, 5, 3]]
示例2
#======================== dim=1 索引=====================#
# 表示提供的采样索引维度dim=1
dim = 1
# index的维度为 (1, 3), 也就是index每个元素的索引为 [[(0, 0), (0, 1), (0, 2)]],
index = torch.tensor([[2, 1, 0]])
# 用index的取值来替代 index的索引列表中的dim=1的元素得:[[(0,2) (0,1) (0,0)]]
# 将input索引为 (0,2) (0,1) (0,0) 取出来就是[[3, 2, 1]]
optput = torch.gather(input, dim, index)
示例3
#=============================================#
dim = 1 # 表示采样索引为1
index = torch.tensor([[2],
[1],
[0]])
# index的索引为[(0, 0),
(1, 0),
(2, 0)]
# 使用index的其余维索引来补全后得到:
'''
[(0, 2),
(1, 1),
(2, 0)]
'''
# 对input索引
output = torch.gather(input, dim, index)
'''
[[3],
[5],
[7]]
'''
示例4
#====================== 多维index =====================#
dim = 1
index = torch.tensor([[0, 2],
[1, 2]])
# index 的索引为 [[(0,0), (0,1)],
# [(1,0), (1,1)]]
# 用除了1维以外的索引将index补全得:
# [[(0,0), (0,2)],
# [(1,1), (1,2)]]
# 对input索引
output = torch.gather(input, dim, index)
#[[0, 3]
# [5, 6]]
更高维度的gather索引也是如此,先生成index每个元素的索引,再用index的值来替代dim维度的索引值,最后按照索引值到input中索引得到output